Add NMAgent Client (#1305)
* Add implementation for GetNetworkConfiguration Previously the NMAgent client did not have support for the GetNetworkConfiguration API call. This adds it and appropriate coverage. * Refactor retry loops to use shared function The cancellable retry was common enough that it made sense to extract it to a separate BackoffRetry function in internal. This made its functionality easier to test and reduced the number of tests necessary for each new endpoint * Slight re-org The client had enough extra stuff in it that it made sense to start separating things into different files * Add retries for Unauthorized responses In the original logic, unauthorized responses are treated as temporary for a specific period of time. This makes the nmagent.Error consider Unauthorized responses as temporary for a configurable time. Given that BackoffRetry cares only whether or not an error is temporary, this naturally causes them to be retried. Additional coverage was added for these scenarios as well. * Add a WireserverTransport This deals with all the quirks of proxying requests to NMAgent through Wireserver, without spreading that concern through the NMAgent client itself. * Reorganize the nmagent internal package The wireserver transport became big enough to warrant its own file * Use WireserverTransport This required some changes to the test so that the WireserverTransport middleware could take effect always * Add PutNetworkContainer method to NMAgent client This is another API that must be implemented * Switch NMAgent client port to uint16 Ports are uint16s by definition. * Add missing body close and context propagation * Add DeleteNetworkContainer endpoint * Move internal imports to another section It's a bit clearer when internal imports are isolated into one section, standard library imports in another, then finally external imports in another section. * Additional Validation / Retry improvements This is a bit of a rollup commit, including some additional validation logic for some nmagent requests and also some improvements to the internal retry logic. The retry logic is now a struct, and the client depends only on an interface for retrying. This is to accommodate the existing retry package (which was unknown). The internal Retrier was enhanced to add a configurable Cooldown function with strategies for Fixed backoff, Exponential, and a Max limitation. * Move GetNetworkConfig request params to a struct This follows the pattern established in other API calls. It moves validation to the request itself and also leaves the responsibility for constructing paths to the request. * Add Validation and Path to put request To be consistent with the other request types, this adds Validate and Path methods to the PutNetworkContainerRequest * Introduce Request and Option Enough was common among building requests and validating them that it made sense to formalize it as an interface of its own. This allowed centralizing the construction of HTTP requests in the nmagent.Client. As such, it made adding TLS disablement trivial. Since there is some optional behavior that can be configured with the nmagent.Client, nmagent.Option has been introduced to handle this in a clean manner. * Add additional error documentation The NMAgent documentation contains some additional documentation as to the meaning of particular HTTP Status codes. Since we have this information, it makes sense to enhance the nmagent.Error so it can explain what the problem is. * Fix issue with cooldown remembering state Previously, cooldown functions were able to retain state across invocations of the "Do" method of the retrier. This adds an additional layer of functions to allow the Retrier to purge the accumulated state * Move Validation to reflection-based helper The validation logic for each struct was repetitive and it didn't help answer the common question "what fields are required by this request struct." This adds a "validate" struct tag that can be used to annotate fields within the request struct and mark them as required. It's still possible to do arbitrary validation within the Validate method of each request, but the common things like "is this field a zero value?" are abstracted into the internal helper. This also serves as documentation to future readers, making it easier to use the package. * Housekeeping: file renaming nmagent.go was really focused on the nmagent.Error type, so it made sense to rename the file to be more revealing. The same goes for internal.go and internal_test.go. Both of those were focused on retry logic. Also added a quick note explaining why client_helpers_test.go exists, since it can be a little subtle to those new to the language. * Remove Vim fold markers While this is nice for vim users, @ramiro-gamarra rightly pointed out that this is a maintenance burden for non-vim users with little benefit. Removing these to reduce the overhead. * Set default scheme to http for nmagent client In practice, most communication for the nmagent client occurs over HTTP because it is intra-node traffic. While this is a useful option to have, the default should be useful for the common use case. * Change retry functions to return durations It was somewhat limiting that cooldown functions themselves would block. What was really interesting about them is that they calculated some time.Duration. Since passing 0 to time.Sleep causes it to return immediately, this has no impact on the AsFastAsPossible strategy Also improved some documentation and added a few examples at the request of @aegal * Rename imports The imports were incorrect because this client was moved from another module. * Duplicate the request in wireserver transport Upon closer reading of the RoundTripper documentation, it's clear that RoundTrippers should not modify the request. While effort has been made to reset any mutations, this is still, technically, modifying the request. Instead, this duplicates the request immediately and re-uses the context that was provided to it. * Drain and close http ResponseBodies It's not entirely clear whether this is needed or not. The documentation for http.(*Client).Do indicates that this is necessary, but experimentation in the community has found that this is maybe not 100% necessary (calling `Close` on the Body appears to be enough). The only harm that can come from this is if Wireserver hands back enormous responses, which is not the case--these responses are fairly small. * Capture unexpected content from Wireserver During certain error cases, Wireserver may return XML. This XML is useful in debugging, so we want to capture it in the error and surface it appropriately. It's unclear whether Wireserver notes the Content-Type, so we use Go's content type detection to figure out what the type of the response is and clean it up to pass along to the NMAgent Client. This also introduces a new ContentError which semantically represents the situation where we were given a content type that we didn't expect. * Don't return a response with an error in RoundTrip The http.Client complains if you return a non-nil response and an error as well. This fixes one instance where that was happening. * Remove extra vim folding marks These were intended to be removed in another commit, but there were some stragglers. * Replace fmt.Errorf with errors.Wrap Even though fmt.Errorf provides an official error-wrapping solution for Go, we have made the decision to use errors.Wrap for its stack collection support. This integrates well with Uber's Zap logger, which we also plan to integrate. * Use Config struct instead of functional Options We determined that a Config struct would be more obvious than the functional options in a debugging scenario. * Remove validation struct tags The validation struct tags were deemed too magical and thus removed in favor of straight-line validation logic. * Address Linter Feedback The linter flagged many items here because it wasn't being run locally during development. This addresses all of the feedback. * Remove the UnauthorizedGracePeriod NMAgent only defines 102 processing as a temporary status. It's up to consumers of the client to determine whether an unauthorized status means that it should be retried or not. * Add error source to NMA error One of the problems with using the WireserverTransport to modify the http status code is that it obscures the source of those errors. Should there be an issue with NMAgent or Wireserver, it will be difficult (or impossible) to figure out which is which. The error itself should tell you, and WireserverTransport knows which component is responsible. This adds a header to the HTTP response and uses that to communicate the responsible party. This is then wired into the outgoing error so that clients can take appropriate action. * Remove leftover unauthorizedGracePeriod These blocks escaped notice when the rest of the UnauthorizedGracePeriod logic was removed from the nmagent client. * Remove extra validation tag This validation tag wasn't noticed when the validation struct tags were removed in a previous commit. * Add the body to the nmagent.Error When errors are returned, it's useful to have the body available for inspection during debugging efforts. This captures the returned body and makes it available in the nmagent.Error. It's also printed when the error is converted to its string representation. * Remove VirtualNetworkID This was redundant, since VNetID covered the same key. It's actually unclear what would happen in this circumstance if this remained, but since it's incorrect this removes it. * Add StatusCode to error Clients still want to be able to communicate the status code in logs, so this includes the StatusCode there as well. * Add GreKey field to PutNetworkContainerRequest In looking at usages, this `greKey` field is undocumented but critical for certain use cases. This adds it so that it remains supported. * Add periods at the end of all docstrings Docstrings should have punctuation since they're documentation. This adds punctuation to every docstring that is exported (and some that aren't). * Remove unused Option type This was leftover from a previous cleanup commit. * Change `error` to a function The `nmagent.(*Client).error` method wasn't actually using any part of `*Client`. Therefore it should be a function. Since we can't use `error` as a function name because it's a reserved keyword, we're throwing back to the Perl days and calling this one `die`.
This commit is contained in:
Родитель
dbb4f68393
Коммит
94f73740e7
|
@ -0,0 +1,201 @@
|
|||
package nmagent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/Azure/azure-container-networking/nmagent/internal"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// NewClient returns an initialized Client using the provided configuration.
|
||||
func NewClient(c Config) (*Client, error) {
|
||||
if err := c.Validate(); err != nil {
|
||||
return nil, errors.Wrap(err, "validating config")
|
||||
}
|
||||
|
||||
client := &Client{
|
||||
httpClient: &http.Client{
|
||||
Transport: &internal.WireserverTransport{
|
||||
Transport: http.DefaultTransport,
|
||||
},
|
||||
},
|
||||
host: c.Host,
|
||||
port: c.Port,
|
||||
enableTLS: c.UseTLS,
|
||||
retrier: internal.Retrier{
|
||||
// nolint:gomnd // the base parameter is explained in the function
|
||||
Cooldown: internal.Exponential(1*time.Second, 2),
|
||||
},
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
// Client is an agent for exchanging information with NMAgent.
|
||||
type Client struct {
|
||||
httpClient *http.Client
|
||||
|
||||
// config
|
||||
host string
|
||||
port uint16
|
||||
|
||||
enableTLS bool
|
||||
|
||||
retrier interface {
|
||||
Do(context.Context, func() error) error
|
||||
}
|
||||
}
|
||||
|
||||
// JoinNetwork joins a node to a customer's virtual network.
|
||||
func (c *Client) JoinNetwork(ctx context.Context, jnr JoinNetworkRequest) error {
|
||||
req, err := c.buildRequest(ctx, jnr)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "building request")
|
||||
}
|
||||
|
||||
err = c.retrier.Do(ctx, func() error {
|
||||
resp, err := c.httpClient.Do(req) // nolint:govet // the shadow is intentional
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "executing request")
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return die(resp.StatusCode, resp.Header, resp.Body)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
return err // nolint:wrapcheck // wrapping this just introduces noise
|
||||
}
|
||||
|
||||
// GetNetworkConfiguration retrieves the configuration of a customer's virtual
|
||||
// network. Only subnets which have been delegated will be returned.
|
||||
func (c *Client) GetNetworkConfiguration(ctx context.Context, gncr GetNetworkConfigRequest) (VirtualNetwork, error) {
|
||||
var out VirtualNetwork
|
||||
|
||||
req, err := c.buildRequest(ctx, gncr)
|
||||
if err != nil {
|
||||
return out, errors.Wrap(err, "building request")
|
||||
}
|
||||
|
||||
err = c.retrier.Do(ctx, func() error {
|
||||
resp, err := c.httpClient.Do(req) // nolint:govet // the shadow is intentional
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "executing http request to")
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return die(resp.StatusCode, resp.Header, resp.Body)
|
||||
}
|
||||
|
||||
ct := resp.Header.Get(internal.HeaderContentType)
|
||||
if ct != internal.MimeJSON {
|
||||
return NewContentError(ct, resp.Body, resp.ContentLength)
|
||||
}
|
||||
|
||||
err = json.NewDecoder(resp.Body).Decode(&out)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "decoding json response")
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
return out, err // nolint:wrapcheck // wrapping just introduces noise here
|
||||
}
|
||||
|
||||
// PutNetworkContainer applies a Network Container goal state and publishes it
|
||||
// to PubSub.
|
||||
func (c *Client) PutNetworkContainer(ctx context.Context, pncr *PutNetworkContainerRequest) error {
|
||||
req, err := c.buildRequest(ctx, pncr)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "building request")
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "submitting request")
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return die(resp.StatusCode, resp.Header, resp.Body)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteNetworkContainer removes a Network Container, its associated IP
|
||||
// addresses, and network policies from an interface.
|
||||
func (c *Client) DeleteNetworkContainer(ctx context.Context, dcr DeleteContainerRequest) error {
|
||||
req, err := c.buildRequest(ctx, dcr)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "building request")
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "submitting request")
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return die(resp.StatusCode, resp.Header, resp.Body)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func die(code int, headers http.Header, body io.ReadCloser) error {
|
||||
// nolint:errcheck // make a best effort to return whatever information we can
|
||||
// returning an error here without the code and source would
|
||||
// be less helpful
|
||||
bodyContent, _ := io.ReadAll(body)
|
||||
return Error{
|
||||
Code: code,
|
||||
// this is a little strange, but the conversion below is to avoid forcing
|
||||
// consumers to depend on an internal type (which they can't anyway)
|
||||
Source: internal.GetErrorSource(headers).String(),
|
||||
Body: bodyContent,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) hostPort() string {
|
||||
port := strconv.Itoa(int(c.port))
|
||||
return net.JoinHostPort(c.host, port)
|
||||
}
|
||||
|
||||
func (c *Client) buildRequest(ctx context.Context, req Request) (*http.Request, error) {
|
||||
if err := req.Validate(); err != nil {
|
||||
return nil, errors.Wrap(err, "validating request")
|
||||
}
|
||||
|
||||
fullURL := &url.URL{
|
||||
Scheme: c.scheme(),
|
||||
Host: c.hostPort(),
|
||||
Path: req.Path(),
|
||||
}
|
||||
|
||||
body, err := req.Body()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "retrieving request body")
|
||||
}
|
||||
|
||||
// nolint:wrapcheck // wrapping doesn't provide useful information
|
||||
return http.NewRequestWithContext(ctx, req.Method(), fullURL.String(), body)
|
||||
}
|
||||
|
||||
func (c *Client) scheme() string {
|
||||
if c.enableTLS {
|
||||
return "https"
|
||||
}
|
||||
return "http"
|
||||
}
|
|
@ -0,0 +1,24 @@
|
|||
package nmagent
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/Azure/azure-container-networking/nmagent/internal"
|
||||
)
|
||||
|
||||
// NewTestClient is a factory function available in tests only for creating
|
||||
// NMAgent clients with a mock transport
|
||||
func NewTestClient(transport http.RoundTripper) *Client {
|
||||
return &Client{
|
||||
httpClient: &http.Client{
|
||||
Transport: &internal.WireserverTransport{
|
||||
Transport: transport,
|
||||
},
|
||||
},
|
||||
host: "localhost",
|
||||
port: 12345,
|
||||
retrier: internal.Retrier{
|
||||
Cooldown: internal.AsFastAsPossible(),
|
||||
},
|
||||
}
|
||||
}
|
|
@ -0,0 +1,434 @@
|
|||
package nmagent_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Azure/azure-container-networking/nmagent"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
var _ http.RoundTripper = &TestTripper{}
|
||||
|
||||
// TestTripper is a RoundTripper with a customizeable RoundTrip method for
|
||||
// testing purposes
|
||||
type TestTripper struct {
|
||||
RoundTripF func(*http.Request) (*http.Response, error)
|
||||
}
|
||||
|
||||
func (t *TestTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return t.RoundTripF(req)
|
||||
}
|
||||
|
||||
func TestNMAgentClientJoinNetwork(t *testing.T) {
|
||||
joinNetTests := []struct {
|
||||
name string
|
||||
id string
|
||||
exp string
|
||||
respStatus int
|
||||
shouldErr bool
|
||||
}{
|
||||
{
|
||||
"happy path",
|
||||
"00000000-0000-0000-0000-000000000000",
|
||||
"/machine/plugins/?comp=nmagent&type=NetworkManagement/joinedVirtualNetworks/00000000-0000-0000-0000-000000000000/api-version/1",
|
||||
http.StatusOK,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"empty network ID",
|
||||
"",
|
||||
"",
|
||||
http.StatusOK, // this shouldn't be checked
|
||||
true,
|
||||
},
|
||||
{
|
||||
"internal error",
|
||||
"00000000-0000-0000-0000-000000000000",
|
||||
"/machine/plugins/?comp=nmagent&type=NetworkManagement/joinedVirtualNetworks/00000000-0000-0000-0000-000000000000/api-version/1",
|
||||
http.StatusInternalServerError,
|
||||
true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range joinNetTests {
|
||||
test := test
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// create a client
|
||||
var got string
|
||||
client := nmagent.NewTestClient(&TestTripper{
|
||||
RoundTripF: func(req *http.Request) (*http.Response, error) {
|
||||
got = req.URL.Path
|
||||
rr := httptest.NewRecorder()
|
||||
_, _ = fmt.Fprintf(rr, `{"httpStatusCode":"%d"}`, test.respStatus)
|
||||
rr.WriteHeader(http.StatusOK)
|
||||
return rr.Result(), nil
|
||||
},
|
||||
})
|
||||
|
||||
// if the test provides a timeout, use it in the context
|
||||
var ctx context.Context
|
||||
if deadline, ok := t.Deadline(); ok {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithDeadline(context.Background(), deadline)
|
||||
defer cancel()
|
||||
} else {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
// attempt to join network
|
||||
// TODO(timraymond): need a more realistic network ID, I think
|
||||
err := client.JoinNetwork(ctx, nmagent.JoinNetworkRequest{test.id})
|
||||
if err != nil && !test.shouldErr {
|
||||
t.Fatal("unexpected error: err:", err)
|
||||
}
|
||||
|
||||
if err == nil && test.shouldErr {
|
||||
t.Fatal("expected error but received none")
|
||||
}
|
||||
|
||||
if got != test.exp {
|
||||
t.Error("received URL differs from expectation: got", got, "exp:", test.exp)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNMAgentClientJoinNetworkRetry(t *testing.T) {
|
||||
// we want to ensure that the client will automatically follow up with
|
||||
// NMAgent, so we want to track the number of requests that it makes
|
||||
invocations := 0
|
||||
exp := 10
|
||||
|
||||
client := nmagent.NewTestClient(&TestTripper{
|
||||
RoundTripF: func(req *http.Request) (*http.Response, error) {
|
||||
rr := httptest.NewRecorder()
|
||||
if invocations < exp {
|
||||
rr.WriteHeader(http.StatusProcessing)
|
||||
invocations++
|
||||
} else {
|
||||
rr.WriteHeader(http.StatusOK)
|
||||
}
|
||||
_, _ = rr.WriteString(`{"httpStatusCode": "200"}`)
|
||||
return rr.Result(), nil
|
||||
},
|
||||
})
|
||||
|
||||
// if the test provides a timeout, use it in the context
|
||||
var ctx context.Context
|
||||
if deadline, ok := t.Deadline(); ok {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithDeadline(context.Background(), deadline)
|
||||
defer cancel()
|
||||
} else {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
// attempt to join network
|
||||
err := client.JoinNetwork(ctx, nmagent.JoinNetworkRequest{"00000000-0000-0000-0000-000000000000"})
|
||||
if err != nil {
|
||||
t.Fatal("unexpected error: err:", err)
|
||||
}
|
||||
|
||||
if invocations != exp {
|
||||
t.Error("client did not make the expected number of API calls: got:", invocations, "exp:", exp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWSError(t *testing.T) {
|
||||
const wsError string = `
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<Error xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xmlns:xsd="http://w
|
||||
ww.w3.org/2001/XMLSchema">
|
||||
<Code>InternalError</Code>
|
||||
<Message>The server encountered an internal error. Please retry the request.
|
||||
</Message>
|
||||
<Details></Details>
|
||||
</Error>
|
||||
`
|
||||
|
||||
client := nmagent.NewTestClient(&TestTripper{
|
||||
RoundTripF: func(req *http.Request) (*http.Response, error) {
|
||||
rr := httptest.NewRecorder()
|
||||
rr.WriteHeader(http.StatusInternalServerError)
|
||||
_, _ = rr.WriteString(wsError)
|
||||
return rr.Result(), nil
|
||||
},
|
||||
})
|
||||
|
||||
req := nmagent.GetNetworkConfigRequest{
|
||||
VNetID: "4815162342",
|
||||
}
|
||||
_, err := client.GetNetworkConfiguration(context.TODO(), req)
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("expected error to not be nil")
|
||||
}
|
||||
|
||||
var cerr nmagent.Error
|
||||
ok := errors.As(err, &cerr)
|
||||
if !ok {
|
||||
t.Fatal("error was not an nmagent.Error")
|
||||
}
|
||||
|
||||
t.Log(cerr.Error())
|
||||
if !strings.Contains(cerr.Error(), "InternalError") {
|
||||
t.Error("error did not contain the error content from wireserver")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNMAgentGetNetworkConfig(t *testing.T) {
|
||||
getTests := []struct {
|
||||
name string
|
||||
vnetID string
|
||||
expURL string
|
||||
expVNet map[string]interface{}
|
||||
shouldCall bool
|
||||
shouldErr bool
|
||||
}{
|
||||
{
|
||||
"happy path",
|
||||
"00000000-0000-0000-0000-000000000000",
|
||||
"/machine/plugins/?comp=nmagent&type=NetworkManagement/joinedVirtualNetworks/00000000-0000-0000-0000-000000000000/api-version/1",
|
||||
map[string]interface{}{
|
||||
"httpStatusCode": "200",
|
||||
"cnetSpace": "10.10.1.0/24",
|
||||
"defaultGateway": "10.10.0.1",
|
||||
"dnsServers": []string{
|
||||
"1.1.1.1",
|
||||
"1.0.0.1",
|
||||
},
|
||||
"subnets": []map[string]interface{}{},
|
||||
"vnetSpace": "10.0.0.0/8",
|
||||
"vnetVersion": "12345",
|
||||
},
|
||||
true,
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range getTests {
|
||||
test := test
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var got string
|
||||
client := nmagent.NewTestClient(&TestTripper{
|
||||
RoundTripF: func(req *http.Request) (*http.Response, error) {
|
||||
rr := httptest.NewRecorder()
|
||||
got = req.URL.Path
|
||||
rr.WriteHeader(http.StatusOK)
|
||||
err := json.NewEncoder(rr).Encode(&test.expVNet)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "encoding response")
|
||||
}
|
||||
|
||||
return rr.Result(), nil
|
||||
},
|
||||
})
|
||||
|
||||
// if the test provides a timeout, use it in the context
|
||||
var ctx context.Context
|
||||
if deadline, ok := t.Deadline(); ok {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithDeadline(context.Background(), deadline)
|
||||
defer cancel()
|
||||
} else {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
gotVNet, err := client.GetNetworkConfiguration(ctx, nmagent.GetNetworkConfigRequest{test.vnetID})
|
||||
if err != nil && !test.shouldErr {
|
||||
t.Fatal("unexpected error: err:", err)
|
||||
}
|
||||
|
||||
if err == nil && test.shouldErr {
|
||||
t.Fatal("expected error but received none")
|
||||
}
|
||||
|
||||
if got != test.expURL && test.shouldCall {
|
||||
t.Error("unexpected URL: got:", got, "exp:", test.expURL)
|
||||
}
|
||||
|
||||
// TODO(timraymond): this is ugly
|
||||
expVnet := nmagent.VirtualNetwork{
|
||||
CNetSpace: test.expVNet["cnetSpace"].(string),
|
||||
DefaultGateway: test.expVNet["defaultGateway"].(string),
|
||||
DNSServers: test.expVNet["dnsServers"].([]string),
|
||||
Subnets: []nmagent.Subnet{},
|
||||
VNetSpace: test.expVNet["vnetSpace"].(string),
|
||||
VNetVersion: test.expVNet["vnetVersion"].(string),
|
||||
}
|
||||
if !cmp.Equal(gotVNet, expVnet) {
|
||||
t.Error("received vnet differs from expected: diff:", cmp.Diff(gotVNet, expVnet))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNMAgentGetNetworkConfigRetry(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
count := 0
|
||||
exp := 10
|
||||
client := nmagent.NewTestClient(&TestTripper{
|
||||
RoundTripF: func(req *http.Request) (*http.Response, error) {
|
||||
rr := httptest.NewRecorder()
|
||||
if count < exp {
|
||||
rr.WriteHeader(http.StatusProcessing)
|
||||
count++
|
||||
} else {
|
||||
rr.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
// we still need a fake response
|
||||
_, _ = rr.WriteString(`{"httpStatusCode": "200"}`)
|
||||
return rr.Result(), nil
|
||||
},
|
||||
})
|
||||
|
||||
// if the test provides a timeout, use it in the context
|
||||
var ctx context.Context
|
||||
if deadline, ok := t.Deadline(); ok {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithDeadline(context.Background(), deadline)
|
||||
defer cancel()
|
||||
} else {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
_, err := client.GetNetworkConfiguration(ctx, nmagent.GetNetworkConfigRequest{"00000000-0000-0000-0000-000000000000"})
|
||||
if err != nil {
|
||||
t.Fatal("unexpected error: err:", err)
|
||||
}
|
||||
|
||||
if count != exp {
|
||||
t.Error("unexpected number of API calls: exp:", exp, "got:", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNMAgentPutNetworkContainer(t *testing.T) {
|
||||
putNCTests := []struct {
|
||||
name string
|
||||
req *nmagent.PutNetworkContainerRequest
|
||||
shouldCall bool
|
||||
shouldErr bool
|
||||
}{
|
||||
{
|
||||
"happy path",
|
||||
&nmagent.PutNetworkContainerRequest{
|
||||
ID: "350f1e3c-4283-4f51-83a1-c44253962ef1",
|
||||
Version: uint64(12345),
|
||||
VNetID: "be3a33e-61e3-42c7-bd23-6b949f57bd36",
|
||||
SubnetName: "TestSubnet",
|
||||
IPv4Addrs: []string{"10.0.0.43"},
|
||||
Policies: []nmagent.Policy{
|
||||
{
|
||||
ID: "policyID1",
|
||||
Type: "type1",
|
||||
},
|
||||
{
|
||||
ID: "policyID2",
|
||||
Type: "type2",
|
||||
},
|
||||
},
|
||||
VlanID: 1234,
|
||||
AuthenticationToken: "swordfish",
|
||||
PrimaryAddress: "10.0.0.1",
|
||||
},
|
||||
true,
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range putNCTests {
|
||||
test := test
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
didCall := false
|
||||
client := nmagent.NewTestClient(&TestTripper{
|
||||
RoundTripF: func(req *http.Request) (*http.Response, error) {
|
||||
rr := httptest.NewRecorder()
|
||||
_, _ = rr.WriteString(`{"httpStatusCode": "200"}`)
|
||||
rr.WriteHeader(http.StatusOK)
|
||||
didCall = true
|
||||
return rr.Result(), nil
|
||||
},
|
||||
})
|
||||
|
||||
err := client.PutNetworkContainer(context.TODO(), test.req)
|
||||
if err != nil && !test.shouldErr {
|
||||
t.Fatal("unexpected error: err", err)
|
||||
}
|
||||
|
||||
if err == nil && test.shouldErr {
|
||||
t.Fatal("expected error but received none")
|
||||
}
|
||||
|
||||
if test.shouldCall && !didCall {
|
||||
t.Fatal("expected call but received none")
|
||||
}
|
||||
|
||||
if !test.shouldCall && didCall {
|
||||
t.Fatal("unexpected call. expected no call ")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNMAgentDeleteNC(t *testing.T) {
|
||||
deleteTests := []struct {
|
||||
name string
|
||||
req nmagent.DeleteContainerRequest
|
||||
exp string
|
||||
shouldErr bool
|
||||
}{
|
||||
{
|
||||
"happy path",
|
||||
nmagent.DeleteContainerRequest{
|
||||
NCID: "00000000-0000-0000-0000-000000000000",
|
||||
PrimaryAddress: "10.0.0.1",
|
||||
AuthenticationToken: "swordfish",
|
||||
},
|
||||
"/machine/plugins/?comp=nmagent&type=NetworkManagement/interfaces/10.0.0.1/networkContainers/00000000-0000-0000-0000-000000000000/authenticationToken/swordfish/api-version/1/method/DELETE",
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
var got string
|
||||
for _, test := range deleteTests {
|
||||
test := test
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
client := nmagent.NewTestClient(&TestTripper{
|
||||
RoundTripF: func(req *http.Request) (*http.Response, error) {
|
||||
got = req.URL.Path
|
||||
rr := httptest.NewRecorder()
|
||||
_, _ = rr.WriteString(`{"httpStatusCode": "200"}`)
|
||||
return rr.Result(), nil
|
||||
},
|
||||
})
|
||||
|
||||
err := client.DeleteNetworkContainer(context.TODO(), test.req)
|
||||
if err != nil && !test.shouldErr {
|
||||
t.Fatal("unexpected error: err:", err)
|
||||
}
|
||||
|
||||
if err == nil && test.shouldErr {
|
||||
t.Fatal("expected error but received none")
|
||||
}
|
||||
|
||||
if test.exp != got {
|
||||
t.Errorf("received URL differs from expectation:\n\texp: %q:\n\tgot: %q", test.exp, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -0,0 +1,36 @@
|
|||
package nmagent
|
||||
|
||||
import "github.com/Azure/azure-container-networking/nmagent/internal"
|
||||
|
||||
// Config is a configuration for an NMAgent Client.
|
||||
type Config struct {
|
||||
/////////////////////
|
||||
// Required Config //
|
||||
/////////////////////
|
||||
Host string // the host the client will connect to
|
||||
Port uint16 // the port the client will connect to
|
||||
|
||||
/////////////////////
|
||||
// Optional Config //
|
||||
/////////////////////
|
||||
UseTLS bool // forces all connections to use TLS
|
||||
}
|
||||
|
||||
// Validate reports whether this configuration is a valid configuration for a
|
||||
// client.
|
||||
func (c Config) Validate() error {
|
||||
err := internal.ValidationError{}
|
||||
|
||||
if c.Host == "" {
|
||||
err.MissingFields = append(err.MissingFields, "Host")
|
||||
}
|
||||
|
||||
if c.Port == 0 {
|
||||
err.MissingFields = append(err.MissingFields, "Port")
|
||||
}
|
||||
|
||||
if err.IsEmpty() {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
|
@ -0,0 +1,59 @@
|
|||
package nmagent_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/Azure/azure-container-networking/nmagent"
|
||||
)
|
||||
|
||||
func TestConfig(t *testing.T) {
|
||||
configTests := []struct {
|
||||
name string
|
||||
config nmagent.Config
|
||||
expValid bool
|
||||
}{
|
||||
{
|
||||
"empty",
|
||||
nmagent.Config{},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"missing port",
|
||||
nmagent.Config{
|
||||
Host: "localhost",
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"missing host",
|
||||
nmagent.Config{
|
||||
Port: 12345,
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"complete",
|
||||
nmagent.Config{
|
||||
Host: "localhost",
|
||||
Port: 12345,
|
||||
},
|
||||
true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range configTests {
|
||||
test := test
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
err := test.config.Validate()
|
||||
if err != nil && test.expValid {
|
||||
t.Fatal("expected config to be valid but wasnt: err:", err)
|
||||
}
|
||||
|
||||
if err == nil && !test.expValid {
|
||||
t.Fatal("expected config to be invalid but wasn't")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -0,0 +1,3 @@
|
|||
// package nmagent contains types and functions necessary for interacting with
|
||||
// the Network Manager Agent (NMAgent).
|
||||
package nmagent
|
|
@ -0,0 +1,105 @@
|
|||
package nmagent
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/Azure/azure-container-networking/nmagent/internal"
|
||||
pkgerrors "github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// ContentError is encountered when an unexpected content type is obtained from
|
||||
// NMAgent.
|
||||
type ContentError struct {
|
||||
Type string // the mime type of the content received
|
||||
Body []byte // the received body
|
||||
}
|
||||
|
||||
func (c ContentError) Error() string {
|
||||
if c.Type == internal.MimeOctetStream {
|
||||
return fmt.Sprintf("unexpected content type %q: body length: %d", c.Type, len(c.Body))
|
||||
}
|
||||
return fmt.Sprintf("unexpected content type %q: body: %s", c.Type, c.Body)
|
||||
}
|
||||
|
||||
// NewContentError creates a ContentError from a provided reader and limit.
|
||||
func NewContentError(contentType string, in io.Reader, limit int64) error {
|
||||
out := ContentError{
|
||||
Type: contentType,
|
||||
Body: make([]byte, limit),
|
||||
}
|
||||
|
||||
bodyReader := io.LimitReader(in, limit)
|
||||
|
||||
read, err := io.ReadFull(bodyReader, out.Body)
|
||||
earlyEOF := errors.Is(err, io.ErrUnexpectedEOF)
|
||||
if err != nil && !earlyEOF {
|
||||
return pkgerrors.Wrap(err, "reading unexpected content body")
|
||||
}
|
||||
|
||||
if earlyEOF {
|
||||
out.Body = out.Body[:read]
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
// Error is a aberrent condition encountered when interacting with the NMAgent
|
||||
// API.
|
||||
type Error struct {
|
||||
Code int // the HTTP status code received
|
||||
Source string // the component responsible for producing the error
|
||||
Body []byte // the body of the error returned
|
||||
}
|
||||
|
||||
// Error constructs a string representation of this error in accordance with
|
||||
// the error interface.
|
||||
func (e Error) Error() string {
|
||||
return fmt.Sprintf("nmagent: %s: http status %d: %s: body: %s", e.source(), e.Code, e.Message(), string(e.Body))
|
||||
}
|
||||
|
||||
func (e Error) source() string {
|
||||
source := "not provided"
|
||||
if e.Source != "" {
|
||||
source = e.Source
|
||||
}
|
||||
return fmt.Sprintf("source: %s", source)
|
||||
}
|
||||
|
||||
// Message interprets the HTTP Status code from NMAgent and returns the
|
||||
// corresponding explanation from the documentation.
|
||||
func (e Error) Message() string {
|
||||
switch e.Code {
|
||||
case http.StatusProcessing:
|
||||
return "the request is taking time to process. the caller should try the request again"
|
||||
case http.StatusUnauthorized:
|
||||
return "the request did not originate from an interface with an OwningServiceInstanceId property"
|
||||
case http.StatusInternalServerError:
|
||||
return "error occurred during nmagent's request processing"
|
||||
default:
|
||||
return "undocumented error"
|
||||
}
|
||||
}
|
||||
|
||||
// Temporary reports whether the error encountered from NMAgent should be
|
||||
// considered temporary, and thus retriable.
|
||||
func (e Error) Temporary() bool {
|
||||
// NMAgent will return a 102 (Processing) if the request is taking time to
|
||||
// complete. These should be attempted again.
|
||||
return e.Code == http.StatusProcessing
|
||||
}
|
||||
|
||||
// StatusCode returns the HTTP status associated with this error.
|
||||
func (e Error) StatusCode() int {
|
||||
return e.Code
|
||||
}
|
||||
|
||||
// Unauthorized reports whether the error was produced as a result of
|
||||
// submitting the request from an interface without an OwningServiceInstanceId
|
||||
// property. In some cases, this can be a transient condition that could be
|
||||
// retried.
|
||||
func (e Error) Unauthorized() bool {
|
||||
return e.Code == http.StatusUnauthorized
|
||||
}
|
|
@ -0,0 +1,63 @@
|
|||
package internal
|
||||
|
||||
import "net/http"
|
||||
|
||||
// Error represents an internal sentinal error which can be defined as a
|
||||
// constant.
|
||||
type Error string
|
||||
|
||||
func (e Error) Error() string {
|
||||
return string(e)
|
||||
}
|
||||
|
||||
// ErrorSource is an indicator used as a header value to indicate the source of
|
||||
// non-2xx status codes.
|
||||
type ErrorSource int
|
||||
|
||||
const (
|
||||
ErrorSourceInvalid ErrorSource = iota
|
||||
ErrorSourceWireserver
|
||||
ErrorSourceNMAgent
|
||||
)
|
||||
|
||||
// String produces the string equivalent for the ErrorSource type.
|
||||
func (e ErrorSource) String() string {
|
||||
switch e {
|
||||
case ErrorSourceWireserver:
|
||||
return "wireserver"
|
||||
case ErrorSourceNMAgent:
|
||||
return "nmagent"
|
||||
case ErrorSourceInvalid:
|
||||
return ""
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// NewErrorSource produces an ErrorSource value from the provided string. Any
|
||||
// unrecognized values will become the invalid type.
|
||||
func NewErrorSource(es string) ErrorSource {
|
||||
switch es {
|
||||
case "wireserver":
|
||||
return ErrorSourceWireserver
|
||||
case "nmagent":
|
||||
return ErrorSourceNMAgent
|
||||
default:
|
||||
return ErrorSourceInvalid
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
HeaderErrorSource = "X-Error-Source"
|
||||
)
|
||||
|
||||
// GetErrorSource retrieves the error source from the provided HTTP headers.
|
||||
func GetErrorSource(head http.Header) ErrorSource {
|
||||
return NewErrorSource(head.Get(HeaderErrorSource))
|
||||
}
|
||||
|
||||
// SetErrorSource sets the header value necessary for communicating the error
|
||||
// source.
|
||||
func SetErrorSource(head *http.Header, es ErrorSource) {
|
||||
head.Set(HeaderErrorSource, es.String())
|
||||
}
|
|
@ -0,0 +1,38 @@
|
|||
package internal
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestErrorSource(t *testing.T) {
|
||||
esTests := []struct {
|
||||
sub string
|
||||
exp string
|
||||
}{
|
||||
{"wireserver", "wireserver"},
|
||||
{"nmagent", "nmagent"},
|
||||
{"garbage", ""},
|
||||
{"", ""},
|
||||
}
|
||||
|
||||
for _, test := range esTests {
|
||||
test := test
|
||||
t.Run(test.sub, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// since this is intended for use with headers, this tests end-to-end
|
||||
es := NewErrorSource(test.sub)
|
||||
|
||||
head := http.Header{}
|
||||
SetErrorSource(&head, es)
|
||||
gotEs := GetErrorSource(head)
|
||||
|
||||
got := gotEs.String()
|
||||
|
||||
if test.exp != got {
|
||||
t.Fatal("received value differs from expectation: exp:", test, "got:", got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -0,0 +1,10 @@
|
|||
package internal
|
||||
|
||||
const (
|
||||
HeaderContentType = "Content-Type"
|
||||
)
|
||||
|
||||
const (
|
||||
MimeJSON = "application/json"
|
||||
MimeOctetStream = "application/octet-stream"
|
||||
)
|
|
@ -0,0 +1,125 @@
|
|||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"math"
|
||||
"time"
|
||||
|
||||
pkgerrors "github.com/pkg/errors"
|
||||
)
|
||||
|
||||
const (
|
||||
noDelay = 0 * time.Nanosecond
|
||||
)
|
||||
|
||||
const (
|
||||
ErrMaxAttempts = Error("maximum attempts reached")
|
||||
)
|
||||
|
||||
// TemporaryError is an error that can indicate whether it may be resolved with
|
||||
// another attempt.
|
||||
type TemporaryError interface {
|
||||
error
|
||||
Temporary() bool
|
||||
}
|
||||
|
||||
// Retrier is a construct for attempting some operation multiple times with a
|
||||
// configurable backoff strategy.
|
||||
type Retrier struct {
|
||||
Cooldown CooldownFactory
|
||||
}
|
||||
|
||||
// Do repeatedly invokes the provided run function while the context remains
|
||||
// active. It waits in between invocations of the provided functions by
|
||||
// delegating to the provided Cooldown function.
|
||||
func (r Retrier) Do(ctx context.Context, run func() error) error {
|
||||
cooldown := r.Cooldown()
|
||||
|
||||
for {
|
||||
if err := ctx.Err(); err != nil {
|
||||
// nolint:wrapcheck // no meaningful information can be added to this error
|
||||
return err
|
||||
}
|
||||
|
||||
err := run()
|
||||
if err != nil {
|
||||
// check to see if it's temporary.
|
||||
var tempErr TemporaryError
|
||||
if ok := errors.As(err, &tempErr); ok && tempErr.Temporary() {
|
||||
delay, err := cooldown() // nolint:govet // the shadow is intentional
|
||||
if err != nil {
|
||||
return pkgerrors.Wrap(err, "sleeping during retry")
|
||||
}
|
||||
time.Sleep(delay)
|
||||
continue
|
||||
}
|
||||
|
||||
// since it's not temporary, it can't be retried, so...
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// CooldownFunc is a function that will block when called. It is intended for
|
||||
// use with retry logic.
|
||||
type CooldownFunc func() (time.Duration, error)
|
||||
|
||||
// CooldownFactory is a function that returns CooldownFuncs. It helps
|
||||
// CooldownFuncs dispose of any accumulated state so that they function
|
||||
// correctly upon successive uses.
|
||||
type CooldownFactory func() CooldownFunc
|
||||
|
||||
// Max provides a fixed limit for the number of times a subordinate cooldown
|
||||
// function can be invoked.
|
||||
func Max(limit int, factory CooldownFactory) CooldownFactory {
|
||||
return func() CooldownFunc {
|
||||
cooldown := factory()
|
||||
count := 0
|
||||
return func() (time.Duration, error) {
|
||||
if count >= limit {
|
||||
return noDelay, ErrMaxAttempts
|
||||
}
|
||||
|
||||
delay, err := cooldown()
|
||||
if err != nil {
|
||||
return noDelay, err
|
||||
}
|
||||
count++
|
||||
return delay, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// AsFastAsPossible is a Cooldown strategy that does not block, allowing retry
|
||||
// logic to proceed as fast as possible. This is particularly useful in tests.
|
||||
func AsFastAsPossible() CooldownFactory {
|
||||
return func() CooldownFunc {
|
||||
return func() (time.Duration, error) {
|
||||
return noDelay, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Exponential provides an exponential increase the the base interval provided.
|
||||
func Exponential(interval time.Duration, base int) CooldownFactory {
|
||||
return func() CooldownFunc {
|
||||
count := 0
|
||||
return func() (time.Duration, error) {
|
||||
increment := math.Pow(float64(base), float64(count))
|
||||
delay := interval.Nanoseconds() * int64(increment)
|
||||
count++
|
||||
return time.Duration(delay), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fixed produced the same delay value upon each invocation.
|
||||
func Fixed(delay time.Duration) CooldownFactory {
|
||||
return func() CooldownFunc {
|
||||
return func() (time.Duration, error) {
|
||||
return delay, nil
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,68 @@
|
|||
package internal
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
func ExampleExponential() {
|
||||
// this example details the common case where the powers of 2 are desired
|
||||
cooldown := Exponential(1*time.Millisecond, 2)()
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
got, err := cooldown()
|
||||
if err != nil {
|
||||
fmt.Println("received error during cooldown: err:", err)
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Println(got)
|
||||
}
|
||||
|
||||
// Output:
|
||||
// 1ms
|
||||
// 2ms
|
||||
// 4ms
|
||||
// 8ms
|
||||
// 16ms
|
||||
}
|
||||
|
||||
func ExampleFixed() {
|
||||
cooldown := Fixed(10 * time.Millisecond)()
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
got, err := cooldown()
|
||||
if err != nil {
|
||||
fmt.Println("unexpected error cooling down: err", err)
|
||||
return
|
||||
}
|
||||
fmt.Println(got)
|
||||
|
||||
// Output:
|
||||
// 10ms
|
||||
// 10ms
|
||||
// 10ms
|
||||
// 10ms
|
||||
// 10ms
|
||||
}
|
||||
}
|
||||
|
||||
func ExampleMax() {
|
||||
cooldown := Max(4, Fixed(10*time.Millisecond))()
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
got, err := cooldown()
|
||||
if err != nil {
|
||||
fmt.Println("error cooling down:", err)
|
||||
break
|
||||
}
|
||||
fmt.Println(got)
|
||||
|
||||
// Output:
|
||||
// 10ms
|
||||
// 10ms
|
||||
// 10ms
|
||||
// 10ms
|
||||
// error cooling down: maximum attempts reached
|
||||
}
|
||||
}
|
|
@ -0,0 +1,164 @@
|
|||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type TestError struct{}
|
||||
|
||||
func (t TestError) Error() string {
|
||||
return "oh no!"
|
||||
}
|
||||
|
||||
func (t TestError) Temporary() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func TestBackoffRetry(t *testing.T) {
|
||||
got := 0
|
||||
exp := 10
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
rt := Retrier{
|
||||
Cooldown: AsFastAsPossible(),
|
||||
}
|
||||
|
||||
err := rt.Do(ctx, func() error {
|
||||
if got < exp {
|
||||
got++
|
||||
return TestError{}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal("unexpected error: err:", err)
|
||||
}
|
||||
|
||||
if got < exp {
|
||||
t.Error("unexpected number of invocations: got:", got, "exp:", exp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackoffRetryWithCancel(t *testing.T) {
|
||||
got := 0
|
||||
exp := 5
|
||||
total := 10
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
rt := Retrier{
|
||||
Cooldown: AsFastAsPossible(),
|
||||
}
|
||||
|
||||
err := rt.Do(ctx, func() error {
|
||||
got++
|
||||
if got >= exp {
|
||||
cancel()
|
||||
}
|
||||
|
||||
if got < total {
|
||||
return TestError{}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if err == nil {
|
||||
t.Error("expected context cancellation error, but received none")
|
||||
}
|
||||
|
||||
if got != exp {
|
||||
t.Error("unexpected number of iterations: exp:", exp, "got:", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackoffRetryUnretriableError(t *testing.T) {
|
||||
rt := Retrier{
|
||||
Cooldown: AsFastAsPossible(),
|
||||
}
|
||||
|
||||
err := rt.Do(context.Background(), func() error {
|
||||
return errors.New("boom") // nolint:goerr113 // it's just a test
|
||||
})
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("expected an error, but none was returned")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFixed(t *testing.T) {
|
||||
exp := 20 * time.Millisecond
|
||||
|
||||
cooldown := Fixed(exp)()
|
||||
|
||||
got, err := cooldown()
|
||||
if err != nil {
|
||||
t.Fatal("unexpected error invoking cooldown: err:", err)
|
||||
}
|
||||
|
||||
if got != exp {
|
||||
t.Fatal("unexpected sleep duration: exp:", exp, "got:", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExp(t *testing.T) {
|
||||
exp := 10 * time.Millisecond
|
||||
base := 2
|
||||
|
||||
cooldown := Exponential(exp, base)()
|
||||
|
||||
first, err := cooldown()
|
||||
if err != nil {
|
||||
t.Fatal("unexpected error invoking cooldown: err:", err)
|
||||
}
|
||||
|
||||
if first != exp {
|
||||
t.Fatal("unexpected sleep during first cooldown: exp:", exp, "got:", first)
|
||||
}
|
||||
|
||||
// ensure that the sleep increases
|
||||
second, err := cooldown()
|
||||
if err != nil {
|
||||
t.Fatal("unexpected error on second invocation of cooldown: err:", err)
|
||||
}
|
||||
|
||||
if second < first {
|
||||
t.Fatal("unexpected sleep during first cooldown: exp:", exp, "got:", second)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMax(t *testing.T) {
|
||||
exp := 10
|
||||
got := 0
|
||||
|
||||
// create a test sleep function
|
||||
fn := func() CooldownFunc {
|
||||
return func() (time.Duration, error) {
|
||||
got++
|
||||
return 0 * time.Nanosecond, nil
|
||||
}
|
||||
}
|
||||
|
||||
cooldown := Max(10, fn)()
|
||||
|
||||
for i := 0; i < exp; i++ {
|
||||
_, err := cooldown()
|
||||
if err != nil {
|
||||
t.Fatal("unexpected error from cooldown: err:", err)
|
||||
}
|
||||
}
|
||||
|
||||
if exp != got {
|
||||
t.Error("unexpected number of cooldown invocations: exp:", exp, "got:", got)
|
||||
}
|
||||
|
||||
// attempt one more, we expect an error
|
||||
_, err := cooldown()
|
||||
if err == nil {
|
||||
t.Errorf("expected an error after %d invocations but received none", exp+1)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,18 @@
|
|||
package internal
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type ValidationError struct {
|
||||
MissingFields []string
|
||||
}
|
||||
|
||||
func (v ValidationError) Error() string {
|
||||
return fmt.Sprintf("missing fields: %s", strings.Join(v.MissingFields, ", "))
|
||||
}
|
||||
|
||||
func (v ValidationError) IsEmpty() bool {
|
||||
return len(v.MissingFields) == 0
|
||||
}
|
|
@ -0,0 +1,184 @@
|
|||
package internal
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
pkgerrors "github.com/pkg/errors"
|
||||
)
|
||||
|
||||
const (
|
||||
// nolint:gomnd // constantizing just obscures meaning here
|
||||
_ int64 = 1 << (10 * iota)
|
||||
kilobyte
|
||||
// megabyte
|
||||
)
|
||||
|
||||
const (
|
||||
WirePrefix string = "/machine/plugins/?comp=nmagent&type="
|
||||
|
||||
// DefaultBufferSize is the maximum number of bytes read from Wireserver in
|
||||
// the event that no Content-Length is provided. The responses are relatively
|
||||
// small, so the smallest page size should be sufficient
|
||||
DefaultBufferSize int64 = 4 * kilobyte
|
||||
|
||||
// errors
|
||||
ErrNoStatusCode = Error("no httpStatusCode property returned in Wireserver response")
|
||||
)
|
||||
|
||||
var _ http.RoundTripper = &WireserverTransport{}
|
||||
|
||||
// WireserverResponse represents a raw response from Wireserver.
|
||||
type WireserverResponse map[string]json.RawMessage
|
||||
|
||||
// StatusCode extracts the embedded HTTP status code from the response from Wireserver.
|
||||
func (w WireserverResponse) StatusCode() (int, error) {
|
||||
if status, ok := w["httpStatusCode"]; ok {
|
||||
var statusStr string
|
||||
err := json.Unmarshal(status, &statusStr)
|
||||
if err != nil {
|
||||
return 0, pkgerrors.Wrap(err, "unmarshaling httpStatusCode from Wireserver")
|
||||
}
|
||||
|
||||
code, err := strconv.Atoi(statusStr)
|
||||
if err != nil {
|
||||
return code, pkgerrors.Wrap(err, "parsing http status code from wireserver")
|
||||
}
|
||||
return code, nil
|
||||
}
|
||||
return 0, ErrNoStatusCode
|
||||
}
|
||||
|
||||
// WireserverTransport is an http.RoundTripper that applies transformation
|
||||
// rules to inbound requests necessary to make them compatible with Wireserver.
|
||||
type WireserverTransport struct {
|
||||
Transport http.RoundTripper
|
||||
}
|
||||
|
||||
// RoundTrip executes arbitrary HTTP requests against Wireserver while applying
|
||||
// the necessary transformation rules to make such requests acceptable to
|
||||
// Wireserver.
|
||||
func (w *WireserverTransport) RoundTrip(inReq *http.Request) (*http.Response, error) {
|
||||
// RoundTrippers are not allowed to modify the request, so we clone it here.
|
||||
// We need to extract the context from the request first since this is _not_
|
||||
// cloned. The dependent Wireserver request should have the same deadline and
|
||||
// cancellation properties as the inbound request though, hence the reuse.
|
||||
ctx := inReq.Context()
|
||||
req := inReq.Clone(ctx)
|
||||
|
||||
// the original path of the request must be prefixed with wireserver's path
|
||||
path := WirePrefix
|
||||
if req.URL.Path != "" {
|
||||
path += req.URL.Path[1:]
|
||||
}
|
||||
|
||||
// the query string from the request must have its constituent parts (?,=,&)
|
||||
// transformed to slashes and appended to the query
|
||||
if req.URL.RawQuery != "" {
|
||||
query := req.URL.RawQuery
|
||||
query = strings.ReplaceAll(query, "?", "/")
|
||||
query = strings.ReplaceAll(query, "=", "/")
|
||||
query = strings.ReplaceAll(query, "&", "/")
|
||||
path += "/" + query
|
||||
}
|
||||
|
||||
req.URL.Path = path
|
||||
|
||||
// wireserver cannot tolerate PUT requests, so it's necessary to transform
|
||||
// those to POSTs
|
||||
if req.Method == http.MethodPut {
|
||||
req.Method = http.MethodPost
|
||||
}
|
||||
|
||||
// all POST requests (and by extension, PUT) must have a non-nil body
|
||||
if req.Method == http.MethodPost && req.Body == nil {
|
||||
req.Body = io.NopCloser(strings.NewReader(""))
|
||||
}
|
||||
|
||||
// execute the request to the downstream transport
|
||||
resp, err := w.Transport.RoundTrip(req)
|
||||
if err != nil {
|
||||
return nil, pkgerrors.Wrap(err, "executing request to wireserver")
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
// something happened at Wireserver, so set a header implicating Wireserver
|
||||
// and hand the response back up
|
||||
SetErrorSource(&resp.Header, ErrorSourceWireserver)
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// at this point we're definitely going to modify the body, so we want to
|
||||
// make sure we close the original request's body, since we're going to
|
||||
// replace it
|
||||
defer func(body io.ReadCloser) {
|
||||
body.Close()
|
||||
}(resp.Body)
|
||||
|
||||
// buffer the entire response from Wireserver
|
||||
clen := resp.ContentLength
|
||||
if clen < 0 {
|
||||
clen = DefaultBufferSize
|
||||
}
|
||||
|
||||
body := make([]byte, clen)
|
||||
bodyReader := io.LimitReader(resp.Body, clen)
|
||||
|
||||
numRead, err := io.ReadFull(bodyReader, body)
|
||||
if err != nil && !errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
return nil, pkgerrors.Wrap(err, "reading response from wireserver")
|
||||
}
|
||||
// it's entirely possible at this point that we read less than we allocated,
|
||||
// so trim the slice back for decoding
|
||||
body = body[:numRead]
|
||||
|
||||
// set the content length properly in case it wasn't set. If it was, this is
|
||||
// effectively a no-op
|
||||
resp.ContentLength = int64(numRead)
|
||||
|
||||
// it's unclear whether Wireserver sets Content-Type appropriately, so we
|
||||
// attempt to decode it first and surface it otherwise.
|
||||
var wsResp WireserverResponse
|
||||
err = json.Unmarshal(body, &wsResp)
|
||||
if err != nil {
|
||||
// probably not JSON, so figure out what it is, pack it up, and surface it
|
||||
// unmodified
|
||||
resp.Header.Set(HeaderContentType, http.DetectContentType(body))
|
||||
resp.Body = io.NopCloser(bytes.NewReader(body))
|
||||
|
||||
// nolint:nilerr // we effectively "fix" this error because it's expected
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// we know that it's JSON now, so communicate that upwards
|
||||
resp.Header.Set(HeaderContentType, MimeJSON)
|
||||
|
||||
// set the response status code with the *real* status code
|
||||
realCode, err := wsResp.StatusCode()
|
||||
if err != nil {
|
||||
return resp, pkgerrors.Wrap(err, "retrieving status code from wireserver response")
|
||||
}
|
||||
|
||||
// add the advisory header stating that any HTTP Status from here out is from
|
||||
// NMAgent
|
||||
SetErrorSource(&resp.Header, ErrorSourceNMAgent)
|
||||
|
||||
resp.StatusCode = realCode
|
||||
|
||||
// re-encode the body and re-attach it to the response
|
||||
delete(wsResp, "httpStatusCode") // TODO(timraymond): concern of the response
|
||||
|
||||
outBody, err := json.Marshal(wsResp)
|
||||
if err != nil {
|
||||
return resp, pkgerrors.Wrap(err, "re-encoding json response from wireserver")
|
||||
}
|
||||
|
||||
resp.Body = io.NopCloser(bytes.NewReader(outBody))
|
||||
|
||||
return resp, nil
|
||||
}
|
|
@ -0,0 +1,393 @@
|
|||
package internal
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type TestTripper struct {
|
||||
// TODO(timraymond): this entire struct is duplicated
|
||||
RoundTripF func(*http.Request) (*http.Response, error)
|
||||
}
|
||||
|
||||
func (t *TestTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return t.RoundTripF(req)
|
||||
}
|
||||
|
||||
func TestWireserverTransportPathTransform(t *testing.T) {
|
||||
// Wireserver introduces specific rules on how requests should be
|
||||
// transformed. This test ensures we got those correct.
|
||||
|
||||
pathTests := []struct {
|
||||
name string
|
||||
method string
|
||||
sub string
|
||||
exp string
|
||||
}{
|
||||
{
|
||||
"happy path",
|
||||
http.MethodGet,
|
||||
"/test/path",
|
||||
"/machine/plugins/?comp=nmagent&type=test/path",
|
||||
},
|
||||
{
|
||||
"empty",
|
||||
http.MethodGet,
|
||||
"",
|
||||
"/machine/plugins/?comp=nmagent&type=",
|
||||
},
|
||||
{
|
||||
"monopath",
|
||||
http.MethodGet,
|
||||
"/foo",
|
||||
"/machine/plugins/?comp=nmagent&type=foo",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range pathTests {
|
||||
test := test
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var got string
|
||||
client := &http.Client{
|
||||
Transport: &WireserverTransport{
|
||||
Transport: &TestTripper{
|
||||
RoundTripF: func(r *http.Request) (*http.Response, error) {
|
||||
got = r.URL.Path
|
||||
rr := httptest.NewRecorder()
|
||||
rr.WriteHeader(http.StatusOK)
|
||||
_, _ = rr.WriteString(`{"httpStatusCode": "200"}`)
|
||||
return rr.Result(), nil
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// execute
|
||||
|
||||
//nolint:noctx // just a test
|
||||
req, err := http.NewRequest(test.method, test.sub, http.NoBody)
|
||||
if err != nil {
|
||||
t.Fatal("error creating new request: err:", err)
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatal("unexpected error submitting request: err:", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// assert
|
||||
if got != test.exp {
|
||||
t.Error("received path differs from expectation: exp:", test.exp, "got:", got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWireserverTransportStatusTransform(t *testing.T) {
|
||||
// Wireserver only responds with 200 or 400 and embeds the actual status code
|
||||
// in JSON. The Transport should correct this and return the actual status as
|
||||
// an actual status
|
||||
|
||||
statusTests := []struct {
|
||||
name string
|
||||
response map[string]interface{}
|
||||
expBody map[string]interface{}
|
||||
expStatus int
|
||||
}{
|
||||
{
|
||||
"401",
|
||||
map[string]interface{}{
|
||||
"httpStatusCode": "401",
|
||||
},
|
||||
map[string]interface{}{},
|
||||
http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
"200 with body",
|
||||
map[string]interface{}{
|
||||
"httpStatusCode": "200",
|
||||
"some": "data",
|
||||
},
|
||||
map[string]interface{}{
|
||||
"some": "data",
|
||||
},
|
||||
http.StatusOK,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range statusTests {
|
||||
test := test
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := &http.Client{
|
||||
Transport: &WireserverTransport{
|
||||
Transport: &TestTripper{
|
||||
RoundTripF: func(r *http.Request) (*http.Response, error) {
|
||||
rr := httptest.NewRecorder()
|
||||
// mimic Wireserver handing back a 200 regardless:
|
||||
rr.WriteHeader(http.StatusOK)
|
||||
|
||||
err := json.NewEncoder(rr).Encode(&test.response)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "encoding json response")
|
||||
}
|
||||
|
||||
return rr.Result(), nil
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// execute
|
||||
|
||||
// nolint:noctx // just a test
|
||||
req, err := http.NewRequest(http.MethodGet, "/test/path", http.NoBody)
|
||||
if err != nil {
|
||||
t.Fatal("error creating new request: err:", err)
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatal("unexpected error submitting request: err:", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// assert
|
||||
gotStatus := resp.StatusCode
|
||||
if gotStatus != test.expStatus {
|
||||
t.Errorf("status codes differ: exp: (%d) %s: got (%d): %s", test.expStatus, http.StatusText(test.expStatus), gotStatus, http.StatusText(gotStatus))
|
||||
}
|
||||
|
||||
var gotBody map[string]interface{}
|
||||
err = json.NewDecoder(resp.Body).Decode(&gotBody)
|
||||
if err != nil {
|
||||
t.Fatal("unexpected error decoding json body: err:", err)
|
||||
}
|
||||
|
||||
if !cmp.Equal(test.expBody, gotBody) {
|
||||
t.Error("received body differs from expected: diff:", cmp.Diff(test.expBody, gotBody))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWireserverTransportPutPost(t *testing.T) {
|
||||
// wireserver can't tolerate PUT requests, so they must be transformed to POSTs
|
||||
t.Parallel()
|
||||
|
||||
var got string
|
||||
client := &http.Client{
|
||||
Transport: &WireserverTransport{
|
||||
Transport: &TestTripper{
|
||||
RoundTripF: func(req *http.Request) (*http.Response, error) {
|
||||
got = req.Method
|
||||
rr := httptest.NewRecorder()
|
||||
_, _ = rr.WriteString(`{"httpStatusCode": "200"}`)
|
||||
rr.WriteHeader(http.StatusOK)
|
||||
return rr.Result(), nil
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodPut, "/test/path", http.NoBody)
|
||||
if err != nil {
|
||||
t.Fatal("unexpected error creating http request: err:", err)
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatal("error submitting request: err:", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
exp := http.MethodPost
|
||||
if got != exp {
|
||||
t.Error("unexpected status: exp:", exp, "got:", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWireserverTransportPostBody(t *testing.T) {
|
||||
// all PUT and POST requests must have an empty string body
|
||||
t.Parallel()
|
||||
|
||||
bodyIsNil := false
|
||||
client := &http.Client{
|
||||
Transport: &WireserverTransport{
|
||||
Transport: &TestTripper{
|
||||
RoundTripF: func(req *http.Request) (*http.Response, error) {
|
||||
bodyIsNil = req.Body == nil
|
||||
rr := httptest.NewRecorder()
|
||||
_, _ = rr.WriteString(`{"httpStatusCode": "200"}`)
|
||||
rr.WriteHeader(http.StatusOK)
|
||||
return rr.Result(), nil
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// PUT request
|
||||
req, err := http.NewRequest(http.MethodPut, "/test/path", http.NoBody)
|
||||
if err != nil {
|
||||
t.Fatal("unexpected error creating http request: err:", err)
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatal("error submitting request: err:", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if bodyIsNil {
|
||||
t.Error("downstream request body to wireserver was nil, but not expected to be")
|
||||
}
|
||||
|
||||
// POST request
|
||||
// nolint:noctx // just a test
|
||||
req, err = http.NewRequest(http.MethodPost, "/test/path", http.NoBody)
|
||||
if err != nil {
|
||||
t.Fatal("unexpected error creating http request: err:", err)
|
||||
}
|
||||
|
||||
resp, err = client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatal("error submitting request: err:", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if bodyIsNil {
|
||||
t.Error("downstream request body to wireserver was nil, but not expected to be")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWireserverTransportQuery(t *testing.T) {
|
||||
// the query string must have its constituent parts converted to slashes and
|
||||
// appended to the path
|
||||
t.Parallel()
|
||||
|
||||
var got string
|
||||
client := &http.Client{
|
||||
Transport: &WireserverTransport{
|
||||
Transport: &TestTripper{
|
||||
RoundTripF: func(req *http.Request) (*http.Response, error) {
|
||||
got = req.URL.Path
|
||||
rr := httptest.NewRecorder()
|
||||
_, _ = rr.WriteString(`{"httpStatusCode": "200"}`)
|
||||
rr.WriteHeader(http.StatusOK)
|
||||
return rr.Result(), nil
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// nolint:noctx // just a test
|
||||
req, err := http.NewRequest(http.MethodPut, "/test/path?api-version=1234&foo=bar", http.NoBody)
|
||||
if err != nil {
|
||||
t.Fatal("unexpected error creating http request: err:", err)
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatal("error submitting request: err:", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
exp := "/machine/plugins/?comp=nmagent&type=test/path/api-version/1234/foo/bar"
|
||||
if got != exp {
|
||||
t.Error("received request differs from expectation: got:", got, "want:", exp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWireserverResponse(t *testing.T) {
|
||||
wsRespTests := []struct {
|
||||
name string
|
||||
resp string
|
||||
exp int
|
||||
shouldErr bool
|
||||
}{
|
||||
{
|
||||
"empty",
|
||||
"{}",
|
||||
0,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"happy path",
|
||||
`{
|
||||
"httpStatusCode": "401"
|
||||
}`,
|
||||
401,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"missing code",
|
||||
`{
|
||||
"httpStatusCode": ""
|
||||
}`,
|
||||
0,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"other stuff",
|
||||
`{
|
||||
"httpStatusCode": "201",
|
||||
"other": "stuff"
|
||||
}`,
|
||||
201,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"not a string",
|
||||
`{
|
||||
"httpStatusCode": 201,
|
||||
"other": "stuff"
|
||||
}`,
|
||||
0,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"processing",
|
||||
`{
|
||||
"httpStatusCode": "102",
|
||||
"other": "stuff"
|
||||
}`,
|
||||
102,
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range wsRespTests {
|
||||
test := test
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var resp WireserverResponse
|
||||
err := json.Unmarshal([]byte(test.resp), &resp)
|
||||
if err != nil {
|
||||
t.Fatal("unexpected unmarshaling error: err:", err)
|
||||
}
|
||||
|
||||
got, err := resp.StatusCode()
|
||||
if err != nil && !test.shouldErr {
|
||||
t.Fatal("unexpected error retrieving status code: err:", err)
|
||||
}
|
||||
|
||||
if err == nil && test.shouldErr {
|
||||
t.Fatal("no error received when one was expected")
|
||||
}
|
||||
|
||||
if got != test.exp {
|
||||
t.Error("received incorrect code: got:", got, "want:", test.exp)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -0,0 +1,129 @@
|
|||
package nmagent_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Azure/azure-container-networking/nmagent"
|
||||
)
|
||||
|
||||
func TestErrorTemp(t *testing.T) {
|
||||
errorTests := []struct {
|
||||
name string
|
||||
err nmagent.Error
|
||||
shouldTemp bool
|
||||
}{
|
||||
{
|
||||
"regular",
|
||||
nmagent.Error{
|
||||
Code: http.StatusInternalServerError,
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"processing",
|
||||
nmagent.Error{
|
||||
Code: http.StatusProcessing,
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"unauthorized",
|
||||
nmagent.Error{
|
||||
Code: http.StatusUnauthorized,
|
||||
},
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range errorTests {
|
||||
test := test
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if test.err.Temporary() && !test.shouldTemp {
|
||||
t.Fatal("test was temporary and not expected to be")
|
||||
}
|
||||
|
||||
if !test.err.Temporary() && test.shouldTemp {
|
||||
t.Fatal("test was not temporary but expected to be")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestContentErrorNew(t *testing.T) {
|
||||
errTests := []struct {
|
||||
name string
|
||||
body io.Reader
|
||||
limit int64
|
||||
contentType string
|
||||
exp string
|
||||
shouldMakeErr bool
|
||||
}{
|
||||
{
|
||||
"empty",
|
||||
strings.NewReader(""),
|
||||
0,
|
||||
"text/plain",
|
||||
"unexpected content type \"text/plain\": body: ",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"happy path",
|
||||
strings.NewReader("random text"),
|
||||
11,
|
||||
"text/plain",
|
||||
"unexpected content type \"text/plain\": body: random text",
|
||||
true,
|
||||
},
|
||||
{
|
||||
// if the body is an octet stream, it's entirely possible that it's
|
||||
// unprintable garbage. This ensures that we just print the length
|
||||
"octets",
|
||||
bytes.NewReader([]byte{0xde, 0xad, 0xbe, 0xef}),
|
||||
4,
|
||||
"application/octet-stream",
|
||||
"unexpected content type \"application/octet-stream\": body length: 4",
|
||||
true,
|
||||
},
|
||||
{
|
||||
// even if the length is wrong, we still want to return as much data as
|
||||
// we can for debugging
|
||||
"wrong len",
|
||||
bytes.NewReader([]byte{0xde, 0xad, 0xbe, 0xef}),
|
||||
8,
|
||||
"application/octet-stream",
|
||||
"unexpected content type \"application/octet-stream\": body length: 4",
|
||||
true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range errTests {
|
||||
test := test
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
err := nmagent.NewContentError(test.contentType, test.body, test.limit)
|
||||
|
||||
var e nmagent.ContentError
|
||||
wasContentErr := errors.As(err, &e)
|
||||
if !wasContentErr && test.shouldMakeErr {
|
||||
t.Fatalf("error was not a ContentError")
|
||||
}
|
||||
|
||||
if wasContentErr && !test.shouldMakeErr {
|
||||
t.Fatalf("received a ContentError when it was not expected")
|
||||
}
|
||||
|
||||
got := err.Error()
|
||||
if got != test.exp {
|
||||
t.Error("unexpected error message: got:", got, "exp:", test.exp)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -0,0 +1,284 @@
|
|||
package nmagent
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/Azure/azure-container-networking/nmagent/internal"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Request represents an abstracted HTTP request, capable of validating itself,
|
||||
// producing a valid Path, Body, and its Method.
|
||||
type Request interface {
|
||||
// Validate should ensure that the request is valid to submit
|
||||
Validate() error
|
||||
|
||||
// Path should produce a URL path, complete with any URL parameters
|
||||
// interpolated
|
||||
Path() string
|
||||
|
||||
// Body produces the HTTP request body necessary to submit the request
|
||||
Body() (io.Reader, error)
|
||||
|
||||
// Method returns the HTTP Method to be used for the request.
|
||||
Method() string
|
||||
}
|
||||
|
||||
var _ Request = &PutNetworkContainerRequest{}
|
||||
|
||||
// PutNetworkContainerRequest is a collection of parameters necessary to create
|
||||
// a new network container
|
||||
type PutNetworkContainerRequest struct {
|
||||
ID string `json:"networkContainerID"` // the id of the network container
|
||||
VNetID string `json:"virtualNetworkID"` // the id of the customer's vnet
|
||||
|
||||
// Version is the new network container version
|
||||
Version uint64 `json:"version"`
|
||||
|
||||
// SubnetName is the name of the delegated subnet. This is used to
|
||||
// authenticate the request. The list of ipv4addresses must be contained in
|
||||
// the subnet's prefix.
|
||||
SubnetName string `json:"subnetName"`
|
||||
|
||||
// IPv4 addresses in the customer virtual network that will be assigned to
|
||||
// the interface.
|
||||
IPv4Addrs []string `json:"ipV4Addresses"`
|
||||
|
||||
Policies []Policy `json:"policies"` // policies applied to the network container
|
||||
|
||||
// VlanID is used to distinguish Network Containers with duplicate customer
|
||||
// addresses. "0" is considered a default value by the API.
|
||||
VlanID int `json:"vlanId"`
|
||||
|
||||
GREKey uint16 `json:"greKey"`
|
||||
|
||||
// AuthenticationToken is the base64 security token for the subnet containing
|
||||
// the Network Container addresses
|
||||
AuthenticationToken string `json:"-"`
|
||||
|
||||
// PrimaryAddress is the primary customer address of the interface in the
|
||||
// management VNet
|
||||
PrimaryAddress string `json:"-"`
|
||||
}
|
||||
|
||||
// Body marshals the JSON fields of the request and produces an Reader intended
|
||||
// for use with an HTTP request
|
||||
func (p *PutNetworkContainerRequest) Body() (io.Reader, error) {
|
||||
body, err := json.Marshal(p)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "marshaling PutNetworkContainerRequest")
|
||||
}
|
||||
|
||||
return bytes.NewReader(body), nil
|
||||
}
|
||||
|
||||
// Method returns the HTTP method for this request type
|
||||
func (p *PutNetworkContainerRequest) Method() string {
|
||||
return http.MethodPost
|
||||
}
|
||||
|
||||
// Path returns the URL path necessary to submit this PutNetworkContainerRequest
|
||||
func (p *PutNetworkContainerRequest) Path() string {
|
||||
const PutNCRequestPath string = "/NetworkManagement/interfaces/%s/networkContainers/%s/authenticationToken/%s/api-version/1"
|
||||
return fmt.Sprintf(PutNCRequestPath, p.PrimaryAddress, p.ID, p.AuthenticationToken)
|
||||
}
|
||||
|
||||
// Validate ensures that all of the required parameters of the request have
|
||||
// been filled out properly prior to submission to NMAgent
|
||||
func (p *PutNetworkContainerRequest) Validate() error {
|
||||
err := internal.ValidationError{}
|
||||
|
||||
if p.Version == 0 {
|
||||
err.MissingFields = append(err.MissingFields, "Version")
|
||||
}
|
||||
|
||||
if p.SubnetName == "" {
|
||||
err.MissingFields = append(err.MissingFields, "SubnetName")
|
||||
}
|
||||
|
||||
if len(p.IPv4Addrs) == 0 {
|
||||
err.MissingFields = append(err.MissingFields, "IPv4Addrs")
|
||||
}
|
||||
|
||||
if p.VNetID == "" {
|
||||
err.MissingFields = append(err.MissingFields, "VNetID")
|
||||
}
|
||||
|
||||
if err.IsEmpty() {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
type Policy struct {
|
||||
ID string
|
||||
Type string
|
||||
}
|
||||
|
||||
// MarshalJson encodes policies as a JSON string, separated by a comma. This
|
||||
// specific format is requested by the NMAgent documentation
|
||||
func (p Policy) MarshalJSON() ([]byte, error) {
|
||||
out := bytes.NewBufferString(p.ID)
|
||||
out.WriteString(", ")
|
||||
out.WriteString(p.Type)
|
||||
|
||||
outStr := out.String()
|
||||
// nolint:wrapcheck // wrapping this error provides no useful information
|
||||
return json.Marshal(outStr)
|
||||
}
|
||||
|
||||
// UnmarshalJSON decodes a JSON-encoded policy string
|
||||
func (p *Policy) UnmarshalJSON(in []byte) error {
|
||||
const expectedNumParts = 2
|
||||
|
||||
var raw string
|
||||
err := json.Unmarshal(in, &raw)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "decoding policy")
|
||||
}
|
||||
|
||||
parts := strings.Split(raw, ",")
|
||||
if len(parts) != expectedNumParts {
|
||||
return errors.New("policies must be two comma-separated values")
|
||||
}
|
||||
|
||||
p.ID = strings.TrimFunc(parts[0], unicode.IsSpace)
|
||||
p.Type = strings.TrimFunc(parts[1], unicode.IsSpace)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var _ Request = JoinNetworkRequest{}
|
||||
|
||||
type JoinNetworkRequest struct {
|
||||
NetworkID string `validate:"presence" json:"-"` // the customer's VNet ID
|
||||
}
|
||||
|
||||
// Path constructs a URL path for invoking a JoinNetworkRequest using the
|
||||
// provided parameters
|
||||
func (j JoinNetworkRequest) Path() string {
|
||||
const JoinNetworkPath string = "/NetworkManagement/joinedVirtualNetworks/%s/api-version/1"
|
||||
return fmt.Sprintf(JoinNetworkPath, j.NetworkID)
|
||||
}
|
||||
|
||||
// Body returns nothing, because JoinNetworkRequest has no request body
|
||||
func (j JoinNetworkRequest) Body() (io.Reader, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Method returns the HTTP request method to submit a JoinNetworkRequest
|
||||
func (j JoinNetworkRequest) Method() string {
|
||||
return http.MethodPost
|
||||
}
|
||||
|
||||
// Validate ensures that the provided parameters of the request are valid
|
||||
func (j JoinNetworkRequest) Validate() error {
|
||||
err := internal.ValidationError{}
|
||||
|
||||
if j.NetworkID == "" {
|
||||
err.MissingFields = append(err.MissingFields, "NetworkID")
|
||||
}
|
||||
|
||||
if err.IsEmpty() {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
var _ Request = DeleteContainerRequest{}
|
||||
|
||||
// DeleteContainerRequest represents all information necessary to request that
|
||||
// NMAgent delete a particular network container
|
||||
type DeleteContainerRequest struct {
|
||||
NCID string `json:"-"` // the Network Container ID
|
||||
|
||||
// PrimaryAddress is the primary customer address of the interface in the
|
||||
// management VNET
|
||||
PrimaryAddress string `json:"-"`
|
||||
AuthenticationToken string `json:"-"`
|
||||
}
|
||||
|
||||
// Path returns the path for submitting a DeleteContainerRequest with
|
||||
// parameters interpolated correctly
|
||||
func (d DeleteContainerRequest) Path() string {
|
||||
const DeleteNCPath string = "/NetworkManagement/interfaces/%s/networkContainers/%s/authenticationToken/%s/api-version/1/method/DELETE"
|
||||
return fmt.Sprintf(DeleteNCPath, d.PrimaryAddress, d.NCID, d.AuthenticationToken)
|
||||
}
|
||||
|
||||
// Body returns nothing, because DeleteContainerRequests have no HTTP body
|
||||
func (d DeleteContainerRequest) Body() (io.Reader, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Method returns the HTTP method required to submit a DeleteContainerRequest
|
||||
func (d DeleteContainerRequest) Method() string {
|
||||
return http.MethodPost
|
||||
}
|
||||
|
||||
// Validate ensures that the DeleteContainerRequest has the correct information
|
||||
// to submit the request
|
||||
func (d DeleteContainerRequest) Validate() error {
|
||||
err := internal.ValidationError{}
|
||||
|
||||
if d.NCID == "" {
|
||||
err.MissingFields = append(err.MissingFields, "NCID")
|
||||
}
|
||||
|
||||
if d.PrimaryAddress == "" {
|
||||
err.MissingFields = append(err.MissingFields, "PrimaryAddress")
|
||||
}
|
||||
|
||||
if d.AuthenticationToken == "" {
|
||||
err.MissingFields = append(err.MissingFields, "AuthenticationToken")
|
||||
}
|
||||
|
||||
if err.IsEmpty() {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
var _ Request = GetNetworkConfigRequest{}
|
||||
|
||||
// GetNetworkConfigRequest is a collection of necessary information for
|
||||
// submitting a request for a customer's network configuration
|
||||
type GetNetworkConfigRequest struct {
|
||||
VNetID string `json:"-"` // the customer's virtual network ID
|
||||
}
|
||||
|
||||
// Path produces a URL path used to submit a request
|
||||
func (g GetNetworkConfigRequest) Path() string {
|
||||
const GetNetworkConfigPath string = "/NetworkManagement/joinedVirtualNetworks/%s/api-version/1"
|
||||
return fmt.Sprintf(GetNetworkConfigPath, g.VNetID)
|
||||
}
|
||||
|
||||
// Body returns nothing because GetNetworkConfigRequest has no HTTP request
|
||||
// body
|
||||
func (g GetNetworkConfigRequest) Body() (io.Reader, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Method returns the HTTP method required to submit a GetNetworkConfigRequest
|
||||
func (g GetNetworkConfigRequest) Method() string {
|
||||
return http.MethodGet
|
||||
}
|
||||
|
||||
// Validate ensures that the request is complete and the parameters are correct
|
||||
func (g GetNetworkConfigRequest) Validate() error {
|
||||
err := internal.ValidationError{}
|
||||
|
||||
if g.VNetID == "" {
|
||||
err.MissingFields = append(err.MissingFields, "VNetID")
|
||||
}
|
||||
|
||||
if err.IsEmpty() {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
|
@ -0,0 +1,402 @@
|
|||
package nmagent_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/Azure/azure-container-networking/nmagent"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
func TestPolicyMarshal(t *testing.T) {
|
||||
policyTests := []struct {
|
||||
name string
|
||||
policy nmagent.Policy
|
||||
exp string
|
||||
}{
|
||||
{
|
||||
"basic",
|
||||
nmagent.Policy{
|
||||
ID: "policyID1",
|
||||
Type: "type1",
|
||||
},
|
||||
"\"policyID1, type1\"",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range policyTests {
|
||||
test := test
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got, err := json.Marshal(test.policy)
|
||||
if err != nil {
|
||||
t.Fatal("unexpected err marshaling policy: err", err)
|
||||
}
|
||||
|
||||
if string(got) != test.exp {
|
||||
t.Errorf("marshaled policy does not match expectation: got: %q: exp: %q", string(got), test.exp)
|
||||
}
|
||||
|
||||
var enc nmagent.Policy
|
||||
err = json.Unmarshal(got, &enc)
|
||||
if err != nil {
|
||||
t.Fatal("unexpected error unmarshaling: err:", err)
|
||||
}
|
||||
|
||||
if !cmp.Equal(enc, test.policy) {
|
||||
t.Error("re-encoded policy differs from expectation: diff:", cmp.Diff(enc, test.policy))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteContainerRequestValidation(t *testing.T) {
|
||||
dcrTests := []struct {
|
||||
name string
|
||||
req nmagent.DeleteContainerRequest
|
||||
shouldBeValid bool
|
||||
}{
|
||||
{
|
||||
"empty",
|
||||
nmagent.DeleteContainerRequest{},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"missing ncid",
|
||||
nmagent.DeleteContainerRequest{
|
||||
PrimaryAddress: "10.0.0.1",
|
||||
AuthenticationToken: "swordfish",
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"missing primary address",
|
||||
nmagent.DeleteContainerRequest{
|
||||
NCID: "00000000-0000-0000-0000-000000000000",
|
||||
AuthenticationToken: "swordfish",
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"missing auth token",
|
||||
nmagent.DeleteContainerRequest{
|
||||
NCID: "00000000-0000-0000-0000-000000000000",
|
||||
PrimaryAddress: "10.0.0.1",
|
||||
},
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range dcrTests {
|
||||
test := test
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
err := test.req.Validate()
|
||||
if err != nil && test.shouldBeValid {
|
||||
t.Fatal("unexpected validation errors: err:", err)
|
||||
}
|
||||
|
||||
if err == nil && !test.shouldBeValid {
|
||||
t.Fatal("expected request to be invalid but wasn't")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestJoinNetworkRequestPath(t *testing.T) {
|
||||
jnr := nmagent.JoinNetworkRequest{
|
||||
NetworkID: "00000000-0000-0000-0000-000000000000",
|
||||
}
|
||||
|
||||
exp := "/NetworkManagement/joinedVirtualNetworks/00000000-0000-0000-0000-000000000000/api-version/1"
|
||||
if jnr.Path() != exp {
|
||||
t.Error("unexpected path: exp:", exp, "got:", jnr.Path())
|
||||
}
|
||||
}
|
||||
|
||||
func TestJoinNetworkRequestValidate(t *testing.T) {
|
||||
validateRequest := []struct {
|
||||
name string
|
||||
req nmagent.JoinNetworkRequest
|
||||
shouldBeValid bool
|
||||
}{
|
||||
{
|
||||
"invalid",
|
||||
nmagent.JoinNetworkRequest{
|
||||
NetworkID: "",
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"valid",
|
||||
nmagent.JoinNetworkRequest{
|
||||
NetworkID: "00000000-0000-0000-0000-000000000000",
|
||||
},
|
||||
true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range validateRequest {
|
||||
test := test
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
err := test.req.Validate()
|
||||
if err != nil && test.shouldBeValid {
|
||||
t.Fatal("unexpected error validating: err:", err)
|
||||
}
|
||||
|
||||
if err == nil && !test.shouldBeValid {
|
||||
t.Fatal("expected request to be invalid but wasn't")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetNetworkConfigRequestPath(t *testing.T) {
|
||||
pathTests := []struct {
|
||||
name string
|
||||
req nmagent.GetNetworkConfigRequest
|
||||
exp string
|
||||
}{
|
||||
{
|
||||
"happy path",
|
||||
nmagent.GetNetworkConfigRequest{
|
||||
VNetID: "00000000-0000-0000-0000-000000000000",
|
||||
},
|
||||
"/NetworkManagement/joinedVirtualNetworks/00000000-0000-0000-0000-000000000000/api-version/1",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range pathTests {
|
||||
test := test
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if got := test.req.Path(); got != test.exp {
|
||||
t.Error("unexpected path: exp:", test.exp, "got:", got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetNetworkConfigRequestValidate(t *testing.T) {
|
||||
validateTests := []struct {
|
||||
name string
|
||||
req nmagent.GetNetworkConfigRequest
|
||||
shouldBeValid bool
|
||||
}{
|
||||
{
|
||||
"happy path",
|
||||
nmagent.GetNetworkConfigRequest{
|
||||
VNetID: "00000000-0000-0000-0000-000000000000",
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"empty",
|
||||
nmagent.GetNetworkConfigRequest{},
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range validateTests {
|
||||
test := test
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
err := test.req.Validate()
|
||||
if err != nil && test.shouldBeValid {
|
||||
t.Fatal("expected request to be valid but wasn't: err:", err)
|
||||
}
|
||||
|
||||
if err == nil && !test.shouldBeValid {
|
||||
t.Fatal("expected error to be invalid but wasn't")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPutNetworkContainerRequestPath(t *testing.T) {
|
||||
pathTests := []struct {
|
||||
name string
|
||||
req nmagent.PutNetworkContainerRequest
|
||||
exp string
|
||||
}{
|
||||
{
|
||||
"happy path",
|
||||
nmagent.PutNetworkContainerRequest{
|
||||
ID: "00000000-0000-0000-0000-000000000000",
|
||||
VNetID: "11111111-1111-1111-1111-111111111111",
|
||||
Version: uint64(12345),
|
||||
SubnetName: "foo",
|
||||
IPv4Addrs: []string{
|
||||
"10.0.0.2",
|
||||
"10.0.0.3",
|
||||
},
|
||||
Policies: []nmagent.Policy{
|
||||
{
|
||||
ID: "Foo",
|
||||
Type: "Bar",
|
||||
},
|
||||
},
|
||||
VlanID: 0,
|
||||
AuthenticationToken: "swordfish",
|
||||
PrimaryAddress: "10.0.0.1",
|
||||
},
|
||||
"/NetworkManagement/interfaces/10.0.0.1/networkContainers/00000000-0000-0000-0000-000000000000/authenticationToken/swordfish/api-version/1",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range pathTests {
|
||||
test := test
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if got := test.req.Path(); got != test.exp {
|
||||
t.Error("path differs from expectation: exp:", test.exp, "got:", got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPutNetworkContainerRequestValidate(t *testing.T) {
|
||||
validationTests := []struct {
|
||||
name string
|
||||
req nmagent.PutNetworkContainerRequest
|
||||
shouldBeValid bool
|
||||
}{
|
||||
{
|
||||
"empty",
|
||||
nmagent.PutNetworkContainerRequest{},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"happy",
|
||||
nmagent.PutNetworkContainerRequest{
|
||||
ID: "00000000-0000-0000-0000-000000000000",
|
||||
VNetID: "11111111-1111-1111-1111-111111111111",
|
||||
Version: uint64(12345),
|
||||
SubnetName: "foo",
|
||||
IPv4Addrs: []string{
|
||||
"10.0.0.2",
|
||||
"10.0.0.3",
|
||||
},
|
||||
Policies: []nmagent.Policy{
|
||||
{
|
||||
ID: "Foo",
|
||||
Type: "Bar",
|
||||
},
|
||||
},
|
||||
VlanID: 0,
|
||||
AuthenticationToken: "swordfish",
|
||||
PrimaryAddress: "10.0.0.1",
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"missing IPv4Addrs",
|
||||
nmagent.PutNetworkContainerRequest{
|
||||
ID: "00000000-0000-0000-0000-000000000000",
|
||||
VNetID: "11111111-1111-1111-1111-111111111111",
|
||||
Version: uint64(12345),
|
||||
SubnetName: "foo",
|
||||
IPv4Addrs: []string{}, // the important part
|
||||
Policies: []nmagent.Policy{
|
||||
{
|
||||
ID: "Foo",
|
||||
Type: "Bar",
|
||||
},
|
||||
},
|
||||
VlanID: 0,
|
||||
AuthenticationToken: "swordfish",
|
||||
PrimaryAddress: "10.0.0.1",
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"missing subnet name",
|
||||
nmagent.PutNetworkContainerRequest{
|
||||
ID: "00000000-0000-0000-0000-000000000000",
|
||||
VNetID: "11111111-1111-1111-1111-111111111111",
|
||||
Version: uint64(12345),
|
||||
SubnetName: "", // the important part of the test
|
||||
IPv4Addrs: []string{
|
||||
"10.0.0.2",
|
||||
},
|
||||
Policies: []nmagent.Policy{
|
||||
{
|
||||
ID: "Foo",
|
||||
Type: "Bar",
|
||||
},
|
||||
},
|
||||
VlanID: 0,
|
||||
AuthenticationToken: "swordfish",
|
||||
PrimaryAddress: "10.0.0.1",
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"missing version",
|
||||
nmagent.PutNetworkContainerRequest{
|
||||
ID: "00000000-0000-0000-0000-000000000000",
|
||||
VNetID: "11111111-1111-1111-1111-111111111111",
|
||||
Version: uint64(0), // the important part of the test
|
||||
SubnetName: "foo",
|
||||
IPv4Addrs: []string{
|
||||
"10.0.0.2",
|
||||
},
|
||||
Policies: []nmagent.Policy{
|
||||
{
|
||||
ID: "Foo",
|
||||
Type: "Bar",
|
||||
},
|
||||
},
|
||||
VlanID: 0,
|
||||
AuthenticationToken: "swordfish",
|
||||
PrimaryAddress: "10.0.0.1",
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"missing vnet id",
|
||||
nmagent.PutNetworkContainerRequest{
|
||||
ID: "00000000-0000-0000-0000-000000000000",
|
||||
VNetID: "", // the important part
|
||||
Version: uint64(12345),
|
||||
SubnetName: "foo",
|
||||
IPv4Addrs: []string{
|
||||
"10.0.0.2",
|
||||
},
|
||||
Policies: []nmagent.Policy{
|
||||
{
|
||||
ID: "Foo",
|
||||
Type: "Bar",
|
||||
},
|
||||
},
|
||||
VlanID: 0,
|
||||
AuthenticationToken: "swordfish",
|
||||
PrimaryAddress: "10.0.0.1",
|
||||
},
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range validationTests {
|
||||
test := test
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
err := test.req.Validate()
|
||||
if err != nil && test.shouldBeValid {
|
||||
t.Fatal("unexpected error validating: err:", err)
|
||||
}
|
||||
|
||||
if err == nil && !test.shouldBeValid {
|
||||
t.Fatal("expected validation error but received none")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -0,0 +1,21 @@
|
|||
package nmagent
|
||||
|
||||
type VirtualNetwork struct {
|
||||
CNetSpace string `json:"cnetSpace"`
|
||||
DefaultGateway string `json:"defaultGateway"`
|
||||
DNSServers []string `json:"dnsServers"`
|
||||
Subnets []Subnet `json:"subnets"`
|
||||
VNetSpace string `json:"vnetSpace"`
|
||||
VNetVersion string `json:"vnetVersion"`
|
||||
}
|
||||
|
||||
type Subnet struct {
|
||||
AddressPrefix string `json:"addressPrefix"`
|
||||
SubnetName string `json:"subnetName"`
|
||||
Tags []Tag `json:"tags"`
|
||||
}
|
||||
|
||||
type Tag struct {
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"` // the type of the tag (e.g. "System" or "Custom")
|
||||
}
|
Загрузка…
Ссылка в новой задаче