package nmagent_test import ( "context" "encoding/json" "fmt" "net/http" "net/http/httptest" "strings" "testing" "github.com/Azure/azure-container-networking/nmagent" "github.com/google/go-cmp/cmp" "github.com/pkg/errors" ) var _ http.RoundTripper = &TestTripper{} // TestTripper is a RoundTripper with a customizeable RoundTrip method for // testing purposes type TestTripper struct { RoundTripF func(*http.Request) (*http.Response, error) } func (t *TestTripper) RoundTrip(req *http.Request) (*http.Response, error) { return t.RoundTripF(req) } func TestNMAgentClientJoinNetwork(t *testing.T) { joinNetTests := []struct { name string id string exp string respStatus int shouldErr bool }{ { "happy path", "00000000-0000-0000-0000-000000000000", "/machine/plugins?comp=nmagent&type=NetworkManagement%2FjoinedVirtualNetworks%2F00000000-0000-0000-0000-000000000000%2Fapi-version%2F1", http.StatusOK, false, }, { "empty network ID", "", "", http.StatusOK, // this shouldn't be checked true, }, { "internal error", "00000000-0000-0000-0000-000000000000", "/machine/plugins?comp=nmagent&type=NetworkManagement%2FjoinedVirtualNetworks%2F00000000-0000-0000-0000-000000000000%2Fapi-version%2F1", http.StatusInternalServerError, true, }, } for _, test := range joinNetTests { test := test t.Run(test.name, func(t *testing.T) { t.Parallel() // create a client var got string client := nmagent.NewTestClient(&TestTripper{ RoundTripF: func(req *http.Request) (*http.Response, error) { got = req.URL.RequestURI() rr := httptest.NewRecorder() _, _ = fmt.Fprintf(rr, `{"httpStatusCode":"%d"}`, test.respStatus) rr.WriteHeader(http.StatusOK) return rr.Result(), nil }, }) ctx, cancel := testContext(t) defer cancel() // attempt to join network err := client.JoinNetwork(ctx, nmagent.JoinNetworkRequest{test.id}) checkErr(t, err, test.shouldErr) if got != test.exp { t.Error("received URL differs from expectation: got", got, "exp:", test.exp) } }) } } func TestNMAgentClientJoinNetworkRetry(t *testing.T) { // we want to ensure that the client will automatically follow up with // NMAgent, so we want to track the number of requests that it makes invocations := 0 exp := 10 client := nmagent.NewTestClient(&TestTripper{ RoundTripF: func(_ *http.Request) (*http.Response, error) { rr := httptest.NewRecorder() if invocations < exp { rr.WriteHeader(http.StatusProcessing) invocations++ } else { rr.WriteHeader(http.StatusOK) } _, _ = rr.WriteString(`{"httpStatusCode": "200"}`) return rr.Result(), nil }, }) ctx, cancel := testContext(t) defer cancel() // attempt to join network err := client.JoinNetwork(ctx, nmagent.JoinNetworkRequest{"00000000-0000-0000-0000-000000000000"}) if err != nil { t.Fatal("unexpected error: err:", err) } if invocations != exp { t.Error("client did not make the expected number of API calls: got:", invocations, "exp:", exp) } } func TestNMAgentClientDeleteNetwork(t *testing.T) { deleteNetTests := []struct { name string id string exp string respStatus int shouldErr bool shouldNotFound bool }{ { "happy path", "00000000-0000-0000-0000-000000000000", "/machine/plugins?comp=nmagent&type=NetworkManagement%2FjoinedVirtualNetworks%2F00000000-0000-0000-0000-000000000000%2Fapi-version%2F1%2Fmethod%2FDELETE", http.StatusOK, false, false, }, { "empty network ID", "", "", http.StatusOK, // this shouldn't be checked true, false, }, { "internal error", "00000000-0000-0000-0000-000000000000", "/machine/plugins?comp=nmagent&type=NetworkManagement%2FjoinedVirtualNetworks%2F00000000-0000-0000-0000-000000000000%2Fapi-version%2F1%2Fmethod%2FDELETE", http.StatusInternalServerError, true, false, }, { "network does not exist", "00000000-0000-0000-0000-000000000000", "/machine/plugins?comp=nmagent&type=NetworkManagement%2FjoinedVirtualNetworks%2F00000000-0000-0000-0000-000000000000%2Fapi-version%2F1%2Fmethod%2FDELETE", http.StatusBadRequest, true, true, }, } for _, test := range deleteNetTests { test := test t.Run(test.name, func(t *testing.T) { t.Parallel() // create a client var got string client := nmagent.NewTestClient(&TestTripper{ RoundTripF: func(req *http.Request) (*http.Response, error) { got = req.URL.RequestURI() rr := httptest.NewRecorder() _, _ = fmt.Fprintf(rr, `{"httpStatusCode":"%d"}`, test.respStatus) rr.WriteHeader(http.StatusOK) return rr.Result(), nil }, }) ctx, cancel := testContext(t) defer cancel() // attempt to delete network err := client.DeleteNetwork(ctx, nmagent.DeleteNetworkRequest{test.id}) checkErr(t, err, test.shouldErr) var nmaError nmagent.Error errors.As(err, &nmaError) if nmaError.NotFound() != test.shouldNotFound { t.Error("unexpected NotFound value: got:", nmaError.NotFound(), "exp:", test.shouldNotFound) } if got != test.exp { t.Error("received URL differs from expectation: got", got, "exp:", test.exp) } }) } } func TestWSError(t *testing.T) { const wsError string = ` InternalError The server encountered an internal error. Please retry the request.
` client := nmagent.NewTestClient(&TestTripper{ RoundTripF: func(_ *http.Request) (*http.Response, error) { rr := httptest.NewRecorder() rr.WriteHeader(http.StatusInternalServerError) _, _ = rr.WriteString(wsError) return rr.Result(), nil }, }) req := nmagent.GetNetworkConfigRequest{ VNetID: "4815162342", } ctx, cancel := testContext(t) defer cancel() _, err := client.GetNetworkConfiguration(ctx, req) if err == nil { t.Fatal("expected error to not be nil") } var cerr nmagent.Error ok := errors.As(err, &cerr) if !ok { t.Fatal("error was not an nmagent.Error") } t.Log(cerr.Error()) if !strings.Contains(cerr.Error(), "InternalError") { t.Error("error did not contain the error content from wireserver") } } func TestNMAgentGetNetworkConfig(t *testing.T) { getTests := []struct { name string vnetID string expURL string expVNet map[string]interface{} shouldCall bool shouldErr bool }{ { "happy path", "00000000-0000-0000-0000-000000000000", "/machine/plugins?comp=nmagent&type=NetworkManagement%2FjoinedVirtualNetworks%2F00000000-0000-0000-0000-000000000000%2Fapi-version%2F1", map[string]interface{}{ "httpStatusCode": "200", "cnetSpace": "10.10.1.0/24", "defaultGateway": "10.10.0.1", "dnsServers": []string{ "1.1.1.1", "1.0.0.1", }, "subnets": []map[string]interface{}{}, "vnetSpace": "10.0.0.0/8", "vnetVersion": "12345", }, true, false, }, } for _, test := range getTests { test := test t.Run(test.name, func(t *testing.T) { t.Parallel() var got string client := nmagent.NewTestClient(&TestTripper{ RoundTripF: func(req *http.Request) (*http.Response, error) { rr := httptest.NewRecorder() got = req.URL.RequestURI() rr.WriteHeader(http.StatusOK) err := json.NewEncoder(rr).Encode(&test.expVNet) if err != nil { return nil, errors.Wrap(err, "encoding response") } return rr.Result(), nil }, }) ctx, cancel := testContext(t) defer cancel() gotVNet, err := client.GetNetworkConfiguration(ctx, nmagent.GetNetworkConfigRequest{test.vnetID}) checkErr(t, err, test.shouldErr) if got != test.expURL && test.shouldCall { t.Error("unexpected URL: got:", got, "exp:", test.expURL) } // TODO(timraymond): this is ugly expVnet := nmagent.VirtualNetwork{ CNetSpace: test.expVNet["cnetSpace"].(string), DefaultGateway: test.expVNet["defaultGateway"].(string), DNSServers: test.expVNet["dnsServers"].([]string), Subnets: []nmagent.Subnet{}, VNetSpace: test.expVNet["vnetSpace"].(string), VNetVersion: test.expVNet["vnetVersion"].(string), } if !cmp.Equal(gotVNet, expVnet) { t.Error("received vnet differs from expected: diff:", cmp.Diff(gotVNet, expVnet)) } }) } } func TestNMAgentGetNetworkConfigRetry(t *testing.T) { t.Parallel() count := 0 exp := 10 client := nmagent.NewTestClient(&TestTripper{ RoundTripF: func(_ *http.Request) (*http.Response, error) { rr := httptest.NewRecorder() if count < exp { rr.WriteHeader(http.StatusProcessing) count++ } else { rr.WriteHeader(http.StatusOK) } // we still need a fake response _, _ = rr.WriteString(`{"httpStatusCode": "200"}`) return rr.Result(), nil }, }) ctx, cancel := testContext(t) defer cancel() _, err := client.GetNetworkConfiguration(ctx, nmagent.GetNetworkConfigRequest{"00000000-0000-0000-0000-000000000000"}) if err != nil { t.Fatal("unexpected error: err:", err) } if count != exp { t.Error("unexpected number of API calls: exp:", exp, "got:", count) } } func TestNMAgentPutNetworkContainer(t *testing.T) { putNCTests := []struct { name string req *nmagent.PutNetworkContainerRequest shouldCall bool shouldErr bool }{ { "happy path", &nmagent.PutNetworkContainerRequest{ ID: "350f1e3c-4283-4f51-83a1-c44253962ef1", Version: uint64(12345), VNetID: "be3a33e-61e3-42c7-bd23-6b949f57bd36", SubnetName: "TestSubnet", IPv4Addrs: []string{"10.0.0.43"}, Policies: []nmagent.Policy{ { ID: "policyID1", Type: "type1", }, { ID: "policyID2", Type: "type2", }, }, VlanID: 1234, AuthenticationToken: "swordfish", PrimaryAddress: "10.0.0.1", }, true, false, }, { "no id", &nmagent.PutNetworkContainerRequest{ Version: uint64(12345), VNetID: "be3a33e-61e3-42c7-bd23-6b949f57bd36", SubnetName: "TestSubnet", IPv4Addrs: []string{"10.0.0.43"}, Policies: []nmagent.Policy{ { ID: "policyID1", Type: "type1", }, { ID: "policyID2", Type: "type2", }, }, VlanID: 1234, AuthenticationToken: "swordfish", PrimaryAddress: "10.0.0.1", }, false, true, }, } for _, test := range putNCTests { test := test t.Run(test.name, func(t *testing.T) { t.Parallel() didCall := false client := nmagent.NewTestClient(&TestTripper{ RoundTripF: func(_ *http.Request) (*http.Response, error) { rr := httptest.NewRecorder() _, _ = rr.WriteString(`{"httpStatusCode": "200"}`) rr.WriteHeader(http.StatusOK) didCall = true return rr.Result(), nil }, }) err := client.PutNetworkContainer(context.TODO(), test.req) if err != nil && !test.shouldErr { t.Fatal("unexpected error: err", err) } if err == nil && test.shouldErr { t.Fatal("expected error but received none") } if test.shouldCall && !didCall { t.Fatal("expected call but received none") } if !test.shouldCall && didCall { t.Fatal("unexpected call. expected no call ") } }) } } func TestNMAgentDeleteNC(t *testing.T) { deleteTests := []struct { name string req nmagent.DeleteContainerRequest exp string shouldErr bool }{ { "happy path", nmagent.DeleteContainerRequest{ NCID: "00000000-0000-0000-0000-000000000000", PrimaryAddress: "10.0.0.1", AuthenticationToken: "swordfish", }, //nolint:lll // not a useful linter in a test "/machine/plugins?comp=nmagent&type=NetworkManagement%2Finterfaces%2F10.0.0.1%2FnetworkContainers%2F00000000-0000-0000-0000-000000000000%2FauthenticationToken%2Fswordfish%2Fapi-version%2F1%2Fmethod%2FDELETE", false, }, } var got string for _, test := range deleteTests { test := test t.Run(test.name, func(t *testing.T) { client := nmagent.NewTestClient(&TestTripper{ RoundTripF: func(req *http.Request) (*http.Response, error) { got = req.URL.RequestURI() rr := httptest.NewRecorder() _, _ = rr.WriteString(`{"httpStatusCode": "200"}`) return rr.Result(), nil }, }) err := client.DeleteNetworkContainer(context.TODO(), test.req) if err != nil && !test.shouldErr { t.Fatal("unexpected error: err:", err) } if err == nil && test.shouldErr { t.Fatal("expected error but received none") } if test.exp != got { t.Errorf("received URL differs from expectation:\n\texp: %q:\n\tgot: %q", test.exp, got) } }) } } func TestNMAgentSupportedAPIs(t *testing.T) { tests := []struct { name string exp []string expPath string resp string shouldErr bool }{ { "empty", nil, "/machine/plugins?comp=nmagent&type=GetSupportedApis", "", false, }, { "happy", []string{"foo"}, "/machine/plugins?comp=nmagent&type=GetSupportedApis", "foo", false, }, } for _, test := range tests { test := test t.Run(test.name, func(t *testing.T) { t.Parallel() var gotPath string client := nmagent.NewTestClient(&TestTripper{ RoundTripF: func(req *http.Request) (*http.Response, error) { gotPath = req.URL.RequestURI() rr := httptest.NewRecorder() _, _ = rr.WriteString(test.resp) return rr.Result(), nil }, }) got, err := client.SupportedAPIs(context.Background()) if err != nil && !test.shouldErr { t.Fatal("unexpected error: err:", err) } if err == nil && test.shouldErr { t.Fatal("expected error but received none") } if gotPath != test.expPath { t.Error("paths differ: got:", gotPath, "exp:", test.expPath) } if !cmp.Equal(got, test.exp) { t.Error("response differs from expectation: diff:", cmp.Diff(got, test.exp)) } }) } } func TestGetNCVersion(t *testing.T) { tests := []struct { name string req nmagent.NCVersionRequest expURL string resp map[string]interface{} shouldErr bool }{ { "empty", nmagent.NCVersionRequest{}, "", map[string]interface{}{}, true, }, { "happy path", nmagent.NCVersionRequest{ AuthToken: "foo", NetworkContainerID: "bar", PrimaryAddress: "baz", }, "/machine/plugins?comp=nmagent&type=NetworkManagement%2Finterfaces%2Fbaz%2FnetworkContainers%2Fbar%2Fversion%2FauthenticationToken%2Ffoo%2Fapi-version%2F1", map[string]interface{}{ "httpStatusCode": "200", "networkContainerId": "bar", "version": "4815162342", }, false, }, { "non-200", nmagent.NCVersionRequest{ AuthToken: "foo", NetworkContainerID: "bar", PrimaryAddress: "baz", }, "/machine/plugins?comp=nmagent&type=NetworkManagement%2Finterfaces%2Fbaz%2FnetworkContainers%2Fbar%2Fversion%2FauthenticationToken%2Ffoo%2Fapi-version%2F1", map[string]interface{}{ "httpStatusCode": "500", }, true, }, } for _, test := range tests { test := test t.Run(test.name, func(t *testing.T) { t.Parallel() var gotURL string client := nmagent.NewTestClient(&TestTripper{ RoundTripF: func(req *http.Request) (*http.Response, error) { gotURL = req.URL.RequestURI() rr := httptest.NewRecorder() err := json.NewEncoder(rr).Encode(test.resp) if err != nil { t.Fatal("unexpected error encoding test response: err:", err) } rr.WriteHeader(http.StatusOK) return rr.Result(), nil }, }) ctx, cancel := testContext(t) defer cancel() got, err := client.GetNCVersion(ctx, test.req) checkErr(t, err, test.shouldErr) if gotURL != test.expURL { t.Error("received URL differs from expected: got:", gotURL, "exp:", test.expURL) } exp := nmagent.NCVersion{} if ncid, ok := test.resp["networkContainerId"]; ok { exp.NetworkContainerID = ncid.(string) } if version, ok := test.resp["version"]; ok { exp.Version = version.(string) } if !cmp.Equal(got, exp) { t.Error("response differs from expectation: diff:", cmp.Diff(got, exp)) } }) } } func TestGetNCVersionList(t *testing.T) { tests := []struct { name string resp map[string]interface{} expURL string exp nmagent.NCVersionList shouldErr bool }{ { "happy path", map[string]interface{}{ "httpStatusCode": "200", "networkContainers": []map[string]interface{}{ { "networkContainerId": "foo", "version": "42", }, }, }, "/machine/plugins?comp=nmagent&type=NetworkManagement%2Finterfaces%2Fapi-version%2F2", nmagent.NCVersionList{ Containers: []nmagent.NCVersion{ { NetworkContainerID: "foo", Version: "42", }, }, }, false, }, { "nma fail", map[string]interface{}{ "httpStatusCode": "500", }, "/machine/plugins?comp=nmagent&type=NetworkManagement%2Finterfaces%2Fapi-version%2F2", nmagent.NCVersionList{}, true, }, } for _, test := range tests { test := test t.Run(test.name, func(t *testing.T) { t.Parallel() var gotURL string client := nmagent.NewTestClient(&TestTripper{ RoundTripF: func(req *http.Request) (*http.Response, error) { gotURL = req.URL.RequestURI() rr := httptest.NewRecorder() rr.WriteHeader(http.StatusOK) err := json.NewEncoder(rr).Encode(test.resp) if err != nil { t.Fatal("unexpected error encoding response: err:", err) } return rr.Result(), nil }, }) ctx, cancel := testContext(t) defer cancel() resp, err := client.GetNCVersionList(ctx) checkErr(t, err, test.shouldErr) if gotURL != test.expURL { t.Error("received URL differs from expected: got:", gotURL, "exp:", test.expURL) } if got := resp; !cmp.Equal(got, test.exp) { t.Error("response differs from expectation: diff:", cmp.Diff(got, test.exp)) } }) } } func TestGetHomeAz(t *testing.T) { tests := []struct { name string exp nmagent.AzResponse expPath string resp map[string]interface{} shouldErr bool }{ { "happy path", nmagent.AzResponse{HomeAz: uint(1)}, "/machine/plugins?comp=nmagent&type=GetHomeAz%2Fapi-version%2F1", map[string]interface{}{ "httpStatusCode": "200", "HomeAz": 1, }, false, }, { "empty response", nmagent.AzResponse{}, "/machine/plugins?comp=nmagent&type=GetHomeAz%2Fapi-version%2F1", map[string]interface{}{ "httpStatusCode": "500", }, true, }, { "404 from NMA", nmagent.AzResponse{}, "/machine/plugins?comp=nmagent&type=GetHomeAz%2Fapi-version%2F1", map[string]interface{}{ "httpStatusCode": "404", }, true, }, } for _, test := range tests { test := test t.Run(test.name, func(t *testing.T) { t.Parallel() client := nmagent.NewTestClient(&TestTripper{ RoundTripF: func(req *http.Request) (*http.Response, error) { rr := httptest.NewRecorder() err := json.NewEncoder(rr).Encode(test.resp) if err != nil { t.Fatal("unexpected error encoding response: err:", err) } rr.WriteHeader(http.StatusOK) return rr.Result(), nil }, }) got, err := client.GetHomeAz(context.TODO()) if err != nil && !test.shouldErr { t.Fatal("unexpected error: err:", err) } if err == nil && test.shouldErr { t.Fatal("expected error but received none") } if !cmp.Equal(got, test.exp) { t.Error("response differs from expectation: diff:", cmp.Diff(got, test.exp)) } }) } }