diff --git a/common/ioshim_windows.go b/common/ioshim_windows.go index 323ffff20..6b6e31a21 100644 --- a/common/ioshim_windows.go +++ b/common/ioshim_windows.go @@ -1,6 +1,8 @@ package common import ( + "testing" + "github.com/Azure/azure-container-networking/network/hnswrapper" testutils "github.com/Azure/azure-container-networking/test/utils" utilexec "k8s.io/utils/exec" @@ -21,6 +23,16 @@ func NewIOShim() *IOShim { func NewMockIOShim(calls []testutils.TestCmd) *IOShim { return &IOShim{ Exec: testutils.GetFakeExecWithScripts(calls), - Hns: &hnswrapper.Hnsv2wrapperFake{}, + Hns: hnswrapper.NewHnsv2wrapperFake(), } } + +func NewMockIOShimWithFakeHNS(hns *hnswrapper.Hnsv2wrapperFake) *IOShim { + return &IOShim{ + Exec: testutils.GetFakeExecWithScripts([]testutils.TestCmd{}), + Hns: hns, + } +} + +// VerifyCalls is used for Unit Testing of linux. In windows this is no-op +func (ioshim *IOShim) VerifyCalls(_ *testing.T, _ []testutils.TestCmd) {} diff --git a/network/hnswrapper/hnsv2wrapperfake.go b/network/hnswrapper/hnsv2wrapperfake.go index a368eb330..05233bfd0 100644 --- a/network/hnswrapper/hnsv2wrapperfake.go +++ b/network/hnswrapper/hnsv2wrapperfake.go @@ -7,13 +7,44 @@ package hnswrapper import ( + "encoding/json" + "errors" + "fmt" + "reflect" + "strings" + "sync" + "github.com/Microsoft/hcsshim/hcn" ) +const networkName = "azure" + +var errorFakeHNS = errors.New("errorFakeHNS Error") + +func newErrorFakeHNS(errStr string) error { + return fmt.Errorf("%w : %s", errorFakeHNS, errStr) +} + type Hnsv2wrapperFake struct { + Cache FakeHNSCache + *sync.Mutex +} + +func NewHnsv2wrapperFake() *Hnsv2wrapperFake { + return &Hnsv2wrapperFake{ + Mutex: &sync.Mutex{}, + Cache: FakeHNSCache{ + networks: map[string]*FakeHostComputeNetwork{}, + endpoints: map[string]*FakeHostComputeEndpoint{}, + }, + } } func (f Hnsv2wrapperFake) CreateNetwork(network *hcn.HostComputeNetwork) (*hcn.HostComputeNetwork, error) { + f.Lock() + defer f.Unlock() + + f.Cache.networks[network.Name] = NewFakeHostComputeNetwork(network) return network, nil } @@ -21,7 +52,104 @@ func (f Hnsv2wrapperFake) DeleteNetwork(network *hcn.HostComputeNetwork) error { return nil } -func (Hnsv2wrapperFake) ModifyNetworkSettings(network *hcn.HostComputeNetwork, request *hcn.ModifyNetworkSettingRequest) error { +func (f Hnsv2wrapperFake) ModifyNetworkSettings(network *hcn.HostComputeNetwork, request *hcn.ModifyNetworkSettingRequest) error { + f.Lock() + defer f.Unlock() + networkCache, ok := f.Cache.networks[network.Name] + if !ok { + return nil + } + switch request.RequestType { + case hcn.RequestTypeAdd: + var setPolSettings hcn.PolicyNetworkRequest + err := json.Unmarshal(request.Settings, &setPolSettings) + if err != nil { + return newErrorFakeHNS(err.Error()) + } + for _, setPolSetting := range setPolSettings.Policies { + if setPolSetting.Type == hcn.SetPolicy { + var setpol hcn.SetPolicySetting + err := json.Unmarshal(setPolSetting.Settings, &setpol) + if err != nil { + return newErrorFakeHNS(err.Error()) + } + if setpol.PolicyType != hcn.SetPolicyTypeIpSet { + // Check Nested SetPolicy members + members := strings.Split(setpol.Values, ",") + for _, memberID := range members { + _, ok := networkCache.Policies[memberID] + if !ok { + return newErrorFakeHNS(fmt.Sprintf("Member Policy %s not found", memberID)) + } + } + } + networkCache.Policies[setpol.Id] = &setpol + } + } + case hcn.RequestTypeRemove: + var setPolSettings hcn.PolicyNetworkRequest + err := json.Unmarshal(request.Settings, &setPolSettings) + if err != nil { + return newErrorFakeHNS(err.Error()) + } + for _, newPolicy := range setPolSettings.Policies { + var setpol hcn.SetPolicySetting + err := json.Unmarshal(newPolicy.Settings, &setpol) + if err != nil { + return newErrorFakeHNS(err.Error()) + } + if _, ok := networkCache.Policies[setpol.Id]; !ok { + return newErrorFakeHNS(fmt.Sprintf("[FakeHNS] could not find %s ipset", setpol.Name)) + } + if setpol.PolicyType == hcn.SetPolicyTypeIpSet { + // For 1st level sets check if they are being referred by nested sets + for _, cacheSet := range networkCache.Policies { + if cacheSet.PolicyType == hcn.SetPolicyTypeIpSet { + continue + } + if strings.Contains(cacheSet.Values, setpol.Id) { + return newErrorFakeHNS(fmt.Sprintf("Set %s is being referred by another %s set", setpol.Name, cacheSet.Name)) + } + } + } + delete(networkCache.Policies, setpol.Id) + } + case hcn.RequestTypeUpdate: + var setPolSettings hcn.PolicyNetworkRequest + err := json.Unmarshal(request.Settings, &setPolSettings) + if err != nil { + return newErrorFakeHNS(err.Error()) + } + for _, newPolicy := range setPolSettings.Policies { + var setpol hcn.SetPolicySetting + err := json.Unmarshal(newPolicy.Settings, &setpol) + if err != nil { + return newErrorFakeHNS(err.Error()) + } + if _, ok := networkCache.Policies[setpol.Id]; !ok { + return newErrorFakeHNS(fmt.Sprintf("[FakeHNS] could not find %s ipset", setpol.Name)) + } + _, ok := networkCache.Policies[setpol.Id] + if !ok { + // Replicating HNS behavior, we will not update non-existent set policy + continue + } + if setpol.PolicyType != hcn.SetPolicyTypeIpSet { + // Check Nested SetPolicy members + members := strings.Split(setpol.Values, ",") + for _, memberID := range members { + _, ok := networkCache.Policies[memberID] + if !ok { + return newErrorFakeHNS(fmt.Sprintf("Member Policy %s not found", memberID)) + } + } + } + networkCache.Policies[setpol.Id] = &setpol + } + case hcn.RequestTypeRefresh: + return nil + } + return nil } @@ -33,25 +161,46 @@ func (Hnsv2wrapperFake) RemoveNetworkPolicy(network *hcn.HostComputeNetwork, net return nil } -func (Hnsv2wrapperFake) GetNetworkByName(networkName string) (*hcn.HostComputeNetwork, error) { +func (f Hnsv2wrapperFake) GetNetworkByName(networkName string) (*hcn.HostComputeNetwork, error) { + f.Lock() + defer f.Unlock() + if network, ok := f.Cache.networks[networkName]; ok { + return network.GetHCNObj(), nil + } return &hcn.HostComputeNetwork{}, nil } func (f Hnsv2wrapperFake) GetNetworkByID(networkID string) (*hcn.HostComputeNetwork, error) { - network := &hcn.HostComputeNetwork{Id: networkID} - return network, nil + f.Lock() + defer f.Unlock() + for _, network := range f.Cache.networks { + if network.ID == networkID { + return network.GetHCNObj(), nil + } + } + return &hcn.HostComputeNetwork{}, nil } func (f Hnsv2wrapperFake) GetEndpointByID(endpointID string) (*hcn.HostComputeEndpoint, error) { - endpoint := &hcn.HostComputeEndpoint{Id: endpointID} + f.Lock() + defer f.Unlock() + if ep, ok := f.Cache.endpoints[endpointID]; ok { + return ep.GetHCNObj(), nil + } + return &hcn.HostComputeEndpoint{}, nil +} + +func (f Hnsv2wrapperFake) CreateEndpoint(endpoint *hcn.HostComputeEndpoint) (*hcn.HostComputeEndpoint, error) { + f.Lock() + defer f.Unlock() + f.Cache.endpoints[endpoint.Id] = NewFakeHostComputeEndpoint(endpoint) return endpoint, nil } -func (Hnsv2wrapperFake) CreateEndpoint(endpoint *hcn.HostComputeEndpoint) (*hcn.HostComputeEndpoint, error) { - return endpoint, nil -} - -func (Hnsv2wrapperFake) DeleteEndpoint(endpoint *hcn.HostComputeEndpoint) error { +func (f Hnsv2wrapperFake) DeleteEndpoint(endpoint *hcn.HostComputeEndpoint) error { + f.Lock() + defer f.Unlock() + delete(f.Cache.endpoints, endpoint.Id) return nil } @@ -68,10 +217,188 @@ func (Hnsv2wrapperFake) RemoveNamespaceEndpoint(namespaceId string, endpointId s return nil } -func (Hnsv2wrapperFake) ListEndpointsOfNetwork(networkId string) ([]hcn.HostComputeEndpoint, error) { - return []hcn.HostComputeEndpoint{}, nil +func (f Hnsv2wrapperFake) ListEndpointsOfNetwork(networkId string) ([]hcn.HostComputeEndpoint, error) { + f.Lock() + defer f.Unlock() + endpoints := make([]hcn.HostComputeEndpoint, 0) + for _, endpoint := range f.Cache.endpoints { + if endpoint.HostComputeNetwork == networkId { + endpoints = append(endpoints, *endpoint.GetHCNObj()) + } + } + return endpoints, nil } -func (Hnsv2wrapperFake) ApplyEndpointPolicy(endpoint *hcn.HostComputeEndpoint, requestType hcn.RequestType, endpointPolicy hcn.PolicyEndpointRequest) error { +func (f Hnsv2wrapperFake) ApplyEndpointPolicy(endpoint *hcn.HostComputeEndpoint, requestType hcn.RequestType, endpointPolicy hcn.PolicyEndpointRequest) error { + f.Lock() + defer f.Unlock() + + epCache, ok := f.Cache.endpoints[endpoint.Id] + if !ok { + return newErrorFakeHNS(fmt.Sprintf("[FakeHNS] could not find endpoint %s", endpoint.Id)) + } + switch requestType { + case hcn.RequestTypeAdd: + for _, newPolicy := range endpointPolicy.Policies { + if newPolicy.Type != hcn.ACL { + continue + } + var aclPol FakeEndpointPolicy + err := json.Unmarshal(newPolicy.Settings, &aclPol) + if err != nil { + return newErrorFakeHNS(err.Error()) + } + epCache.Policies = append(epCache.Policies, &aclPol) + } + case hcn.RequestTypeRemove: + for _, newPolicy := range endpointPolicy.Policies { + if newPolicy.Type != hcn.ACL { + continue + } + var aclPol FakeEndpointPolicy + err := json.Unmarshal(newPolicy.Settings, &aclPol) + if err != nil { + return newErrorFakeHNS(err.Error()) + } + err = epCache.RemovePolicy(&aclPol) + if err != nil { + return err + } + } + case hcn.RequestTypeUpdate: + epCache.Policies = make([]*FakeEndpointPolicy, 0) + for _, newPolicy := range endpointPolicy.Policies { + if newPolicy.Type != hcn.ACL { + continue + } + var aclPol FakeEndpointPolicy + err := json.Unmarshal(newPolicy.Settings, &aclPol) + if err != nil { + return newErrorFakeHNS(err.Error()) + } + epCache.Policies = append(epCache.Policies, &aclPol) + } + case hcn.RequestTypeRefresh: + return nil + } + return nil } + +type FakeHNSCache struct { + networks map[string]*FakeHostComputeNetwork + endpoints map[string]*FakeHostComputeEndpoint +} + +func (fCache FakeHNSCache) SetPolicy(setID string) *hcn.SetPolicySetting { + for _, network := range fCache.networks { + for _, policy := range network.Policies { + if policy.Id == setID { + return policy + } + } + } + return nil +} + +func (fCache FakeHNSCache) ACLPolicies(epList map[string]string, policyID string) (map[string][]*FakeEndpointPolicy, error) { + aclPols := make(map[string][]*FakeEndpointPolicy) + for ip, epID := range epList { + epCache, ok := fCache.endpoints[epID] + if !ok { + return nil, newErrorFakeHNS(fmt.Sprintf("[FakeHNS] could not find endpoint %s", epID)) + } + if epCache.IPConfiguration != ip { + return nil, newErrorFakeHNS(fmt.Sprintf("[FakeHNS] Mismatch in IP addr of endpoint %s Got: %s, Expect %s", + epID, epCache.IPConfiguration, ip)) + } + aclPols[epID] = make([]*FakeEndpointPolicy, 0) + for _, policy := range epCache.Policies { + if policy.ID == policyID { + aclPols[epID] = append(aclPols[epID], policy) + } + } + + } + return aclPols, nil +} + +func (fCache FakeHNSCache) GetAllACLs() map[string][]*FakeEndpointPolicy { + aclPols := make(map[string][]*FakeEndpointPolicy) + for _, ep := range fCache.endpoints { + aclPols[ep.ID] = ep.Policies + } + return aclPols +} + +type FakeHostComputeNetwork struct { + ID string + Name string + Policies map[string]*hcn.SetPolicySetting +} + +func NewFakeHostComputeNetwork(network *hcn.HostComputeNetwork) *FakeHostComputeNetwork { + return &FakeHostComputeNetwork{ + ID: network.Id, + Name: network.Name, + Policies: make(map[string]*hcn.SetPolicySetting), + } +} + +func (fNetwork *FakeHostComputeNetwork) GetHCNObj() *hcn.HostComputeNetwork { + return &hcn.HostComputeNetwork{ + Id: fNetwork.ID, + Name: fNetwork.Name, + } +} + +type FakeHostComputeEndpoint struct { + ID string + Name string + HostComputeNetwork string + Policies []*FakeEndpointPolicy + IPConfiguration string +} + +func NewFakeHostComputeEndpoint(endpoint *hcn.HostComputeEndpoint) *FakeHostComputeEndpoint { + ip := "" + if endpoint.IpConfigurations != nil { + ip = endpoint.IpConfigurations[0].IpAddress + } + return &FakeHostComputeEndpoint{ + ID: endpoint.Id, + Name: endpoint.Name, + HostComputeNetwork: endpoint.HostComputeNetwork, + IPConfiguration: ip, + } +} + +func (fEndpoint *FakeHostComputeEndpoint) GetHCNObj() *hcn.HostComputeEndpoint { + return &hcn.HostComputeEndpoint{ + Id: fEndpoint.ID, + Name: fEndpoint.Name, + HostComputeNetwork: fEndpoint.HostComputeNetwork, + } +} + +func (fEndpoint *FakeHostComputeEndpoint) RemovePolicy(toRemovePol *FakeEndpointPolicy) error { + for i, policy := range fEndpoint.Policies { + if reflect.DeepEqual(policy, toRemovePol) { + fEndpoint.Policies = append(fEndpoint.Policies[:i], fEndpoint.Policies[i+1:]...) + return nil + } + } + return newErrorFakeHNS(fmt.Sprintf("Could not find policy %+v", toRemovePol)) +} + +type FakeEndpointPolicy struct { + ID string `json:",omitempty"` + Protocols string `json:",omitempty"` // EX: 6 (TCP), 17 (UDP), 1 (ICMPv4), 58 (ICMPv6), 2 (IGMP) + Action hcn.ActionType `json:","` + Direction hcn.DirectionType `json:","` + LocalAddresses string `json:",omitempty"` + RemoteAddresses string `json:",omitempty"` + LocalPorts string `json:",omitempty"` + RemotePorts string `json:",omitempty"` + Priority int `json:",omitempty"` +} diff --git a/npm/pkg/dataplane/ipsets/ipsetmanager.go b/npm/pkg/dataplane/ipsets/ipsetmanager.go index 13b6e1870..f2d5b10bf 100644 --- a/npm/pkg/dataplane/ipsets/ipsetmanager.go +++ b/npm/pkg/dataplane/ipsets/ipsetmanager.go @@ -69,6 +69,7 @@ func (iMgr *IPSetManager) CreateIPSets(setMetadatas []*IPSetMetadata) { } func (iMgr *IPSetManager) createIPSet(setMetadata *IPSetMetadata) { + // TODO (vamsi) check for os specific restrictions on ipsets prefixedName := setMetadata.GetPrefixName() if iMgr.exists(prefixedName) { return diff --git a/npm/pkg/dataplane/ipsets/ipsetmanager_test.go b/npm/pkg/dataplane/ipsets/ipsetmanager_test.go index 50108c4fc..7b919642d 100644 --- a/npm/pkg/dataplane/ipsets/ipsetmanager_test.go +++ b/npm/pkg/dataplane/ipsets/ipsetmanager_test.go @@ -19,10 +19,17 @@ const ( testPodIP = "10.0.0.0" ) -var iMgrApplyOnNeedCfg = &IPSetManagerCfg{ - IPSetMode: ApplyOnNeed, - NetworkName: "azure", -} +var ( + iMgrApplyOnNeedCfg = &IPSetManagerCfg{ + IPSetMode: ApplyOnNeed, + NetworkName: "azure", + } + + iMgrApplyAlwaysCfg = &IPSetManagerCfg{ + IPSetMode: ApplyAllIPSets, + NetworkName: "azure", + } +) func TestCreateIPSet(t *testing.T) { iMgr := NewIPSetManager(iMgrApplyOnNeedCfg, common.NewMockIOShim([]testutils.TestCmd{})) @@ -40,6 +47,22 @@ func TestCreateIPSet(t *testing.T) { assert.Equal(t, util.GetHashedName(setMetadata.GetPrefixName()), set.HashedName) } +func TestCreateIPSetApplyAlways(t *testing.T) { + iMgr := NewIPSetManager(iMgrApplyAlwaysCfg, common.NewMockIOShim([]testutils.TestCmd{})) + + setMetadata := NewIPSetMetadata(testSetName, Namespace) + iMgr.CreateIPSets([]*IPSetMetadata{setMetadata}) + // creating twice + iMgr.CreateIPSets([]*IPSetMetadata{setMetadata}) + + assert.True(t, iMgr.exists(setMetadata.GetPrefixName())) + + set := iMgr.GetIPSet(setMetadata.GetPrefixName()) + require.NotNil(t, set) + assert.Equal(t, setMetadata.GetPrefixName(), set.Name) + assert.Equal(t, util.GetHashedName(setMetadata.GetPrefixName()), set.HashedName) +} + func TestAddToSet(t *testing.T) { iMgr := NewIPSetManager(iMgrApplyOnNeedCfg, common.NewMockIOShim([]testutils.TestCmd{})) diff --git a/npm/pkg/dataplane/ipsets/ipsetmanager_windows.go b/npm/pkg/dataplane/ipsets/ipsetmanager_windows.go index 19d36ec37..ea601e500 100644 --- a/npm/pkg/dataplane/ipsets/ipsetmanager_windows.go +++ b/npm/pkg/dataplane/ipsets/ipsetmanager_windows.go @@ -172,27 +172,43 @@ func (iMgr *IPSetManager) getHCnNetwork() (*hcn.HostComputeNetwork, error) { func (iMgr *IPSetManager) modifySetPolicies(network *hcn.HostComputeNetwork, operation hcn.RequestType, setPolicies map[string]*hcn.SetPolicySetting) error { klog.Infof("[IPSetManager Windows] %s operation on set policies is called", operation) - policyRequest, err := getPolicyNetworkRequestMarshal(setPolicies, operation) - if err != nil { - klog.Infof("[IPSetManager Windows] Failed to marshal %s operations sets with error %s", operation, err.Error()) - return err - } + /* + Due to complexities in HNS, we need to do the following: + for (Add) + 1. Add 1st level set policies to HNS + 2. then add nested set policies to HNS - if policyRequest == nil { - klog.Infof("[IPSetManager Windows] No Policies to apply") - return nil + for (delete) + 1. delete nested set policies from HNS + 2. then delete 1st level set policies from HNS + */ + policySettingsOrder := []hcn.SetPolicyType{hcn.SetPolicyTypeIpSet, SetPolicyTypeNestedIPSet} + if operation == hcn.RequestTypeRemove { + policySettingsOrder = []hcn.SetPolicyType{SetPolicyTypeNestedIPSet, hcn.SetPolicyTypeIpSet} } + for _, policyType := range policySettingsOrder { + policyRequest, err := getPolicyNetworkRequestMarshal(setPolicies, policyType) + if err != nil { + klog.Infof("[IPSetManager Windows] Failed to marshal %s operations sets with error %s", operation, err.Error()) + return err + } - requestMessage := &hcn.ModifyNetworkSettingRequest{ - ResourceType: hcn.NetworkResourceTypePolicy, - RequestType: operation, - Settings: policyRequest, - } + if policyRequest == nil { + klog.Infof("[IPSetManager Windows] No Policies to apply") + return nil + } - err = iMgr.ioShim.Hns.ModifyNetworkSettings(network, requestMessage) - if err != nil { - klog.Infof("[IPSetManager Windows] %s operation has failed with error %s", operation, err.Error()) - return err + requestMessage := &hcn.ModifyNetworkSettingRequest{ + ResourceType: hcn.NetworkResourceTypePolicy, + RequestType: operation, + Settings: policyRequest, + } + + err = iMgr.ioShim.Hns.ModifyNetworkSettings(network, requestMessage) + if err != nil { + klog.Infof("[IPSetManager Windows] %s operation has failed with error %s", operation, err.Error()) + return err + } } return nil } @@ -234,37 +250,32 @@ func (setPolicyBuilder *networkPolicyBuilder) setNameExists(setName string) bool return ok } -func getPolicyNetworkRequestMarshal(setPolicySettings map[string]*hcn.SetPolicySetting, operation hcn.RequestType) ([]byte, error) { - policyNetworkRequest := &hcn.PolicyNetworkRequest{ - Policies: make([]hcn.NetworkPolicy, len(setPolicySettings)), - } - +func getPolicyNetworkRequestMarshal(setPolicySettings map[string]*hcn.SetPolicySetting, policyType hcn.SetPolicyType) ([]byte, error) { if len(setPolicySettings) == 0 { klog.Info("[Dataplane Windows] no set policies to apply on network") return nil, nil } - - idx := 0 - policySettingsOrder := []hcn.SetPolicyType{SetPolicyTypeNestedIPSet, hcn.SetPolicyTypeIpSet} - if operation == hcn.RequestTypeRemove { - policySettingsOrder = []hcn.SetPolicyType{hcn.SetPolicyTypeIpSet, SetPolicyTypeNestedIPSet} + klog.Infof("[Dataplane Windows] marshalling %s type of sets", policyType) + policyNetworkRequest := &hcn.PolicyNetworkRequest{ + Policies: make([]hcn.NetworkPolicy, 0), } - for _, policyType := range policySettingsOrder { - for setPol := range setPolicySettings { - if setPolicySettings[setPol].PolicyType != policyType { - continue - } - klog.Infof("Found set pol %+v", setPolicySettings[setPol]) - rawSettings, err := json.Marshal(setPolicySettings[setPol]) - if err != nil { - return nil, err - } - policyNetworkRequest.Policies[idx] = hcn.NetworkPolicy{ + + for setPol := range setPolicySettings { + if setPolicySettings[setPol].PolicyType != policyType { + continue + } + klog.Infof("Found set pol %+v", setPolicySettings[setPol]) + rawSettings, err := json.Marshal(setPolicySettings[setPol]) + if err != nil { + return nil, err + } + policyNetworkRequest.Policies = append( + policyNetworkRequest.Policies, + hcn.NetworkPolicy{ Type: hcn.SetPolicy, Settings: rawSettings, - } - idx++ - } + }, + ) } policyReqSettings, err := json.Marshal(policyNetworkRequest) diff --git a/npm/pkg/dataplane/ipsets/ipsetmanager_windows_test.go b/npm/pkg/dataplane/ipsets/ipsetmanager_windows_test.go new file mode 100644 index 000000000..f2377e32a --- /dev/null +++ b/npm/pkg/dataplane/ipsets/ipsetmanager_windows_test.go @@ -0,0 +1,326 @@ +package ipsets + +import ( + "fmt" + "testing" + + "github.com/Azure/azure-container-networking/common" + "github.com/Azure/azure-container-networking/network/hnswrapper" + testutils "github.com/Azure/azure-container-networking/test/utils" + "github.com/Microsoft/hcsshim/hcn" + "github.com/stretchr/testify/require" +) + +func TestAddToSetWindows(t *testing.T) { + hns := GetHNSFake(t) + io := common.NewMockIOShimWithFakeHNS(hns) + iMgr := NewIPSetManager(iMgrApplyAlwaysCfg, io) + + setMetadata := NewIPSetMetadata(testSetName, Namespace) + iMgr.CreateIPSets([]*IPSetMetadata{setMetadata}) + + err := iMgr.AddToSets([]*IPSetMetadata{setMetadata}, testPodIP, testPodKey) + require.NoError(t, err) + + err = iMgr.AddToSets([]*IPSetMetadata{setMetadata}, "2001:db8:0:0:0:0:2:1", "newpod") + require.NoError(t, err) + + // same IP changed podkey + err = iMgr.AddToSets([]*IPSetMetadata{setMetadata}, testPodIP, "newpod") + require.NoError(t, err) + + listMetadata := NewIPSetMetadata("testipsetlist", KeyLabelOfNamespace) + iMgr.CreateIPSets([]*IPSetMetadata{listMetadata}) + err = iMgr.AddToSets([]*IPSetMetadata{listMetadata}, testPodIP, testPodKey) + require.Error(t, err) + + err = iMgr.ApplyIPSets() + require.NoError(t, err) +} + +func TestDestroyNPMIPSets(t *testing.T) { + iMgr := NewIPSetManager(iMgrApplyAlwaysCfg, common.NewMockIOShim([]testutils.TestCmd{})) + require.NoError(t, iMgr.resetIPSets()) +} + +// create all possible SetTypes +func TestApplyCreationsAndAdds(t *testing.T) { + hns := GetHNSFake(t) + io := common.NewMockIOShimWithFakeHNS(hns) + iMgr := NewIPSetManager(iMgrApplyAlwaysCfg, io) + + iMgr.CreateIPSets([]*IPSetMetadata{TestNSSet.Metadata}) + require.NoError(t, iMgr.AddToSets([]*IPSetMetadata{TestNSSet.Metadata}, "10.0.0.0", "a")) + require.NoError(t, iMgr.AddToSets([]*IPSetMetadata{TestNSSet.Metadata}, "10.0.0.1", "b")) + iMgr.CreateIPSets([]*IPSetMetadata{TestKeyPodSet.Metadata}) + require.NoError(t, iMgr.AddToSets([]*IPSetMetadata{TestKeyPodSet.Metadata}, "10.0.0.5", "c")) + iMgr.CreateIPSets([]*IPSetMetadata{TestKVPodSet.Metadata}) + iMgr.CreateIPSets([]*IPSetMetadata{TestNamedportSet.Metadata}) + iMgr.CreateIPSets([]*IPSetMetadata{TestCIDRSet.Metadata}) + iMgr.CreateIPSets([]*IPSetMetadata{TestKeyNSList.Metadata}) + require.NoError(t, iMgr.AddToLists([]*IPSetMetadata{TestKeyNSList.Metadata}, []*IPSetMetadata{TestNSSet.Metadata, TestKeyPodSet.Metadata})) + iMgr.CreateIPSets([]*IPSetMetadata{TestKVNSList.Metadata}) + require.NoError(t, iMgr.AddToLists([]*IPSetMetadata{TestKVNSList.Metadata}, []*IPSetMetadata{TestKVPodSet.Metadata})) + iMgr.CreateIPSets([]*IPSetMetadata{TestNestedLabelList.Metadata}) + toAddOrUpdateSetMap := map[string]hcn.SetPolicySetting{ + TestNSSet.PrefixName: { + Id: TestNSSet.HashedName, + PolicyType: hcn.SetPolicyTypeIpSet, + Name: TestNSSet.PrefixName, + Values: "10.0.0.0,10.0.0.1", + }, + TestKeyPodSet.PrefixName: { + Id: TestKeyPodSet.HashedName, + PolicyType: hcn.SetPolicyTypeIpSet, + Name: TestKeyPodSet.PrefixName, + Values: "10.0.0.5", + }, + TestKVPodSet.PrefixName: { + Id: TestKVPodSet.HashedName, + PolicyType: hcn.SetPolicyTypeIpSet, + Name: TestKVPodSet.PrefixName, + Values: "", + }, + TestNamedportSet.PrefixName: { + Id: TestNamedportSet.HashedName, + PolicyType: hcn.SetPolicyTypeIpSet, + Name: TestNamedportSet.PrefixName, + Values: "", + }, + TestCIDRSet.PrefixName: { + Id: TestCIDRSet.HashedName, + PolicyType: hcn.SetPolicyTypeIpSet, + Name: TestCIDRSet.PrefixName, + Values: "", + }, + TestKeyNSList.PrefixName: { + Id: TestKeyNSList.HashedName, + PolicyType: SetPolicyTypeNestedIPSet, + Name: TestKeyNSList.PrefixName, + Values: fmt.Sprintf("%s,%s", TestNSSet.HashedName, TestKeyPodSet.HashedName), + }, + TestKVNSList.PrefixName: { + Id: TestKVNSList.HashedName, + PolicyType: SetPolicyTypeNestedIPSet, + Name: TestKVNSList.PrefixName, + Values: TestKVPodSet.HashedName, + }, + TestNestedLabelList.PrefixName: { + Id: TestNestedLabelList.HashedName, + PolicyType: SetPolicyTypeNestedIPSet, + Name: TestNestedLabelList.PrefixName, + Values: "", + }, + } + err := iMgr.ApplyIPSets() + require.NoError(t, err) + verifyHNSCache(t, toAddOrUpdateSetMap, hns) +} + +func TestApplyDeletions(t *testing.T) { + hns := GetHNSFake(t) + io := common.NewMockIOShimWithFakeHNS(hns) + iMgr := NewIPSetManager(iMgrApplyAlwaysCfg, io) + + // Remove members and delete others + iMgr.CreateIPSets([]*IPSetMetadata{TestNSSet.Metadata}) + require.NoError(t, iMgr.AddToSets([]*IPSetMetadata{TestNSSet.Metadata}, "10.0.0.0", "a")) + require.NoError(t, iMgr.AddToSets([]*IPSetMetadata{TestNSSet.Metadata}, "10.0.0.1", "b")) + iMgr.CreateIPSets([]*IPSetMetadata{TestKeyPodSet.Metadata}) + iMgr.CreateIPSets([]*IPSetMetadata{TestKeyNSList.Metadata}) + require.NoError(t, iMgr.AddToLists([]*IPSetMetadata{TestKeyNSList.Metadata}, []*IPSetMetadata{TestNSSet.Metadata, TestKeyPodSet.Metadata})) + require.NoError(t, iMgr.RemoveFromSets([]*IPSetMetadata{TestNSSet.Metadata}, "10.0.0.1", "b")) + require.NoError(t, iMgr.RemoveFromList(TestKeyNSList.Metadata, []*IPSetMetadata{TestKeyPodSet.Metadata})) + iMgr.CreateIPSets([]*IPSetMetadata{TestCIDRSet.Metadata}) + iMgr.DeleteIPSet(TestCIDRSet.PrefixName) + iMgr.CreateIPSets([]*IPSetMetadata{TestNestedLabelList.Metadata}) + iMgr.DeleteIPSet(TestNestedLabelList.PrefixName) + + toDeleteSetNames := []string{TestCIDRSet.PrefixName, TestNestedLabelList.PrefixName} + toAddOrUpdateSetMap := map[string]hcn.SetPolicySetting{ + TestNSSet.PrefixName: { + Id: TestNSSet.HashedName, + PolicyType: hcn.SetPolicyTypeIpSet, + Name: TestNSSet.PrefixName, + Values: "10.0.0.0", + }, + TestKeyPodSet.PrefixName: { + Id: TestKeyPodSet.HashedName, + PolicyType: hcn.SetPolicyTypeIpSet, + Name: TestKeyPodSet.PrefixName, + Values: "", + }, + TestKeyNSList.PrefixName: { + Id: TestKeyNSList.HashedName, + PolicyType: SetPolicyTypeNestedIPSet, + Name: TestKeyNSList.PrefixName, + Values: TestNSSet.HashedName, + }, + } + + err := iMgr.ApplyIPSets() + require.NoError(t, err) + verifyHNSCache(t, toAddOrUpdateSetMap, hns) + verifyDeletedHNSCache(t, toDeleteSetNames, hns) +} + +// TODO test that a reconcile list is updated +func TestFailureOnCreation(t *testing.T) { + hns := GetHNSFake(t) + io := common.NewMockIOShimWithFakeHNS(hns) + iMgr := NewIPSetManager(iMgrApplyAlwaysCfg, io) + + iMgr.CreateIPSets([]*IPSetMetadata{TestNSSet.Metadata}) + require.NoError(t, iMgr.AddToSets([]*IPSetMetadata{TestNSSet.Metadata}, "10.0.0.0", "a")) + require.NoError(t, iMgr.AddToSets([]*IPSetMetadata{TestNSSet.Metadata}, "10.0.0.1", "b")) + iMgr.CreateIPSets([]*IPSetMetadata{TestKeyPodSet.Metadata}) + require.NoError(t, iMgr.AddToSets([]*IPSetMetadata{TestKeyPodSet.Metadata}, "10.0.0.5", "c")) + iMgr.CreateIPSets([]*IPSetMetadata{TestCIDRSet.Metadata}) + iMgr.DeleteIPSet(TestCIDRSet.PrefixName) + + toDeleteSetNames := []string{TestCIDRSet.PrefixName} + toAddOrUpdateSetMap := map[string]hcn.SetPolicySetting{ + TestNSSet.PrefixName: { + Id: TestNSSet.HashedName, + PolicyType: hcn.SetPolicyTypeIpSet, + Name: TestNSSet.PrefixName, + Values: "10.0.0.0,10.0.0.1", + }, + TestKeyPodSet.PrefixName: { + Id: TestKeyPodSet.HashedName, + PolicyType: hcn.SetPolicyTypeIpSet, + Name: TestKeyPodSet.PrefixName, + Values: "10.0.0.5", + }, + } + + err := iMgr.ApplyIPSets() + require.NoError(t, err) + verifyHNSCache(t, toAddOrUpdateSetMap, hns) + verifyDeletedHNSCache(t, toDeleteSetNames, hns) +} + +// TODO test that a reconcile list is updated +func TestFailureOnAddToList(t *testing.T) { + // This exact scenario wouldn't occur. This error happens when the cache is out of date with the kernel. + hns := GetHNSFake(t) + io := common.NewMockIOShimWithFakeHNS(hns) + iMgr := NewIPSetManager(iMgrApplyAlwaysCfg, io) + + iMgr.CreateIPSets([]*IPSetMetadata{TestNSSet.Metadata}) + require.NoError(t, iMgr.AddToSets([]*IPSetMetadata{TestNSSet.Metadata}, "10.0.0.0", "a")) + iMgr.CreateIPSets([]*IPSetMetadata{TestKeyPodSet.Metadata}) + iMgr.CreateIPSets([]*IPSetMetadata{TestKeyNSList.Metadata}) + require.NoError(t, iMgr.AddToLists([]*IPSetMetadata{TestKeyNSList.Metadata}, []*IPSetMetadata{TestNSSet.Metadata, TestKeyPodSet.Metadata})) + iMgr.CreateIPSets([]*IPSetMetadata{TestKVNSList.Metadata}) + require.NoError(t, iMgr.AddToLists([]*IPSetMetadata{TestKVNSList.Metadata}, []*IPSetMetadata{TestNSSet.Metadata})) + iMgr.CreateIPSets([]*IPSetMetadata{TestCIDRSet.Metadata}) + iMgr.DeleteIPSet(TestCIDRSet.PrefixName) + + toDeleteSetNames := []string{TestCIDRSet.PrefixName} + toAddOrUpdateSetMap := map[string]hcn.SetPolicySetting{ + TestNSSet.PrefixName: { + Id: TestNSSet.HashedName, + PolicyType: hcn.SetPolicyTypeIpSet, + Name: TestNSSet.PrefixName, + Values: "10.0.0.0", + }, + TestKeyPodSet.PrefixName: { + Id: TestKeyPodSet.HashedName, + PolicyType: hcn.SetPolicyTypeIpSet, + Name: TestKeyPodSet.PrefixName, + Values: "", + }, + TestKeyNSList.PrefixName: { + Id: TestKeyNSList.HashedName, + PolicyType: SetPolicyTypeNestedIPSet, + Name: TestKeyNSList.PrefixName, + Values: fmt.Sprintf("%s,%s", TestNSSet.HashedName, TestKeyPodSet.HashedName), + }, + TestKVNSList.PrefixName: { + Id: TestKVNSList.HashedName, + PolicyType: SetPolicyTypeNestedIPSet, + Name: TestKVNSList.PrefixName, + Values: TestNSSet.HashedName, + }, + } + + err := iMgr.ApplyIPSets() + require.NoError(t, err) + verifyHNSCache(t, toAddOrUpdateSetMap, hns) + verifyDeletedHNSCache(t, toDeleteSetNames, hns) +} + +// TODO test that a reconcile list is updated +func TestFailureOnFlush(t *testing.T) { + // This exact scenario wouldn't occur. This error happens when the cache is out of date with the kernel. + hns := GetHNSFake(t) + io := common.NewMockIOShimWithFakeHNS(hns) + iMgr := NewIPSetManager(iMgrApplyAlwaysCfg, io) + + iMgr.CreateIPSets([]*IPSetMetadata{TestNSSet.Metadata}) + require.NoError(t, iMgr.AddToSets([]*IPSetMetadata{TestNSSet.Metadata}, "10.0.0.0", "a")) + iMgr.CreateIPSets([]*IPSetMetadata{TestKVPodSet.Metadata}) + iMgr.DeleteIPSet(TestKVPodSet.PrefixName) + iMgr.CreateIPSets([]*IPSetMetadata{TestCIDRSet.Metadata}) + iMgr.DeleteIPSet(TestCIDRSet.PrefixName) + + toDeleteSetNames := []string{TestKVPodSet.PrefixName, TestCIDRSet.PrefixName} + toAddOrUpdateSetMap := map[string]hcn.SetPolicySetting{ + TestNSSet.PrefixName: { + Id: TestNSSet.HashedName, + PolicyType: hcn.SetPolicyTypeIpSet, + Name: TestNSSet.PrefixName, + Values: "10.0.0.0", + }, + } + + err := iMgr.ApplyIPSets() + require.NoError(t, err) + verifyHNSCache(t, toAddOrUpdateSetMap, hns) + verifyDeletedHNSCache(t, toDeleteSetNames, hns) +} + +// TODO test that a reconcile list is updated +func TestFailureOnDeletion(t *testing.T) { + hns := GetHNSFake(t) + io := common.NewMockIOShimWithFakeHNS(hns) + iMgr := NewIPSetManager(iMgrApplyAlwaysCfg, io) + + iMgr.CreateIPSets([]*IPSetMetadata{TestNSSet.Metadata}) + require.NoError(t, iMgr.AddToSets([]*IPSetMetadata{TestNSSet.Metadata}, "10.0.0.0", "a")) + iMgr.CreateIPSets([]*IPSetMetadata{TestKVPodSet.Metadata}) + iMgr.DeleteIPSet(TestKVPodSet.PrefixName) + iMgr.CreateIPSets([]*IPSetMetadata{TestCIDRSet.Metadata}) + iMgr.DeleteIPSet(TestCIDRSet.PrefixName) + + toDeleteSetNames := []string{TestKVPodSet.PrefixName, TestCIDRSet.PrefixName} + toAddOrUpdateSetMap := map[string]hcn.SetPolicySetting{ + TestNSSet.PrefixName: { + Id: TestNSSet.HashedName, + PolicyType: hcn.SetPolicyTypeIpSet, + Name: TestNSSet.PrefixName, + Values: "10.0.0.0", + }, + } + + err := iMgr.ApplyIPSets() + require.NoError(t, err) + verifyHNSCache(t, toAddOrUpdateSetMap, hns) + verifyDeletedHNSCache(t, toDeleteSetNames, hns) +} + +func verifyHNSCache(t *testing.T, expected map[string]hcn.SetPolicySetting, hns *hnswrapper.Hnsv2wrapperFake) { + for setName, setObj := range expected { + cacheObj := hns.Cache.SetPolicy(setObj.Id) + require.NotNil(t, cacheObj) + require.Equal(t, setObj, *cacheObj, fmt.Sprintf("%s mismatch in cache", setName)) + } +} + +func verifyDeletedHNSCache(t *testing.T, deleted []string, hns *hnswrapper.Hnsv2wrapperFake) { + for _, setName := range deleted { + cacheObj := hns.Cache.SetPolicy(setName) + require.Nil(t, cacheObj) + } +} diff --git a/npm/pkg/dataplane/ipsets/testutils_windows.go b/npm/pkg/dataplane/ipsets/testutils_windows.go index 2a97b83d2..3421b3ee8 100644 --- a/npm/pkg/dataplane/ipsets/testutils_windows.go +++ b/npm/pkg/dataplane/ipsets/testutils_windows.go @@ -1,6 +1,26 @@ package ipsets -import testutils "github.com/Azure/azure-container-networking/test/utils" +import ( + "testing" + + "github.com/Azure/azure-container-networking/network/hnswrapper" + testutils "github.com/Azure/azure-container-networking/test/utils" + "github.com/stretchr/testify/require" + "github.com/Microsoft/hcsshim/hcn" +) + +func GetHNSFake(t *testing.T) *hnswrapper.Hnsv2wrapperFake { + hns := hnswrapper.NewHnsv2wrapperFake() + network := &hcn.HostComputeNetwork{ + Id: "1234", + Name: "azure", + } + + _, err := hns.CreateNetwork(network) + require.NoError(t, err) + + return hns +} func GetApplyIPSetsTestCalls(_, _ []*IPSetMetadata) []testutils.TestCmd { return []testutils.TestCmd{} diff --git a/npm/pkg/dataplane/policies/policy_windows.go b/npm/pkg/dataplane/policies/policy_windows.go index 7f1c68257..c2cdfe085 100644 --- a/npm/pkg/dataplane/policies/policy_windows.go +++ b/npm/pkg/dataplane/policies/policy_windows.go @@ -8,6 +8,11 @@ import ( "github.com/Microsoft/hcsshim/hcn" ) +const ( + blockRulePriotity = 3000 + allowRulePriotity = 222 +) + var ( protocolNumMap = map[Protocol]string{ TCP: "6", @@ -69,9 +74,9 @@ func (acl *ACLPolicy) convertToAclSettings() (*NPMACLPolSettings, error) { policySettings.Action = getHCNAction(acl.Target) // TODO need to have better priority handling - policySettings.Priority = uint16(222) + policySettings.Priority = uint16(allowRulePriotity) if policySettings.Action == hcn.ActionTypeBlock { - policySettings.Priority = uint16(3000) + policySettings.Priority = uint16(blockRulePriotity) } if acl.Protocol == "" { acl.Protocol = AnyProtocol @@ -100,9 +105,11 @@ func (acl *ACLPolicy) convertToAclSettings() (*NPMACLPolSettings, error) { policySettings.LocalAddresses = srcListStr policySettings.RemoteAddresses = dstListStr policySettings.RemotePorts = dstPortStr + policySettings.LocalPorts = "" if policySettings.Direction == hcn.DirectionTypeOut { policySettings.LocalAddresses = dstListStr policySettings.LocalPorts = dstPortStr + policySettings.RemotePorts = "" policySettings.RemoteAddresses = srcListStr } diff --git a/npm/pkg/dataplane/policies/policymanager.go b/npm/pkg/dataplane/policies/policymanager.go index ffe8e7eb4..eea62a051 100644 --- a/npm/pkg/dataplane/policies/policymanager.go +++ b/npm/pkg/dataplane/policies/policymanager.go @@ -10,7 +10,18 @@ import ( "k8s.io/klog" ) -const reconcileTimeInMinutes = 5 +// PolicyManagerMode will be used in windows to decide if +// SetPolicies should be used or not +type PolicyManagerMode string + +const ( + // IPSetPolicyMode will references IPSets in policies + IPSetPolicyMode PolicyManagerMode = "IPSet" + // IPPolicyMode will replace ipset names with their value IPs in policies + IPPolicyMode PolicyManagerMode = "IP" + + reconcileTimeInMinutes = 5 +) type PolicyMap struct { cache map[string]*NPMNetworkPolicy @@ -20,6 +31,7 @@ type PolicyManager struct { policyMap *PolicyMap ioShim *common.IOShim staleChains *staleChains + *PolicyManagerCfg sync.Mutex } @@ -40,6 +52,10 @@ func (pMgr *PolicyManager) Initialize() error { return nil } +type PolicyManagerCfg struct { + Mode PolicyManagerMode +} + func (pMgr *PolicyManager) Reset() error { if err := pMgr.reset(); err != nil { return npmerrors.ErrorWrapper(npmerrors.ResetPolicyMgr, false, "failed to reset policy manager", err) diff --git a/npm/pkg/dataplane/policies/policymanager_windows.go b/npm/pkg/dataplane/policies/policymanager_windows.go index acbcb5526..fb206a438 100644 --- a/npm/pkg/dataplane/policies/policymanager_windows.go +++ b/npm/pkg/dataplane/policies/policymanager_windows.go @@ -4,7 +4,6 @@ import ( "encoding/json" "errors" "fmt" - "strings" "github.com/Microsoft/hcsshim/hcn" "k8s.io/klog" @@ -141,6 +140,8 @@ func (pMgr *PolicyManager) removePolicy(policy *NPMNetworkPolicy, endpointList m } epBuilder.compareAndRemovePolicies(rulesToRemove[0].Id, len(rulesToRemove)) + klog.Infof("[DataPlanewindows] Epbuilder ACL policies before removing %+v", epBuilder.aclPolicies) + klog.Infof("[DataPlanewindows] Epbuilder Other policies before removing %+v", epBuilder.otherPolicies) epPolicies, err := epBuilder.getHCNPolicyRequest() if err != nil { aggregateErr = fmt.Errorf("[DataPlanewindows] Skipping removing policies on %s ID Endpoint with %s err\n Previous %w", epID, err.Error(), aggregateErr) @@ -271,15 +272,18 @@ func (epBuilder *endpointPolicyBuilder) compareAndRemovePolicies(ruleIDToRemove // All ACl policies in a given Netpol will have the same ID // starting with "azure-acl-" prefix aclFound := false + toDeleteIndexes := map[int]struct{}{} for i, acl := range epBuilder.aclPolicies { // First check if ID is present and equal, this saves compute cycles to compare both objects if ruleIDToRemove == acl.Id { // Remove the ACL policy from the list - epBuilder.removeACLPolicyAtIndex(i) + klog.Infof("[DataPlane Windows] Found ACL with ID %s and removing it", acl.Id) + toDeleteIndexes[i] = struct{}{} lenOfRulesToRemove-- aclFound = true } } + epBuilder.removeACLPolicyAtIndex(toDeleteIndexes) // If ACl Policies are not found, it means that we might have removed them earlier // or never applied them if !aclFound { @@ -294,19 +298,15 @@ func (epBuilder *endpointPolicyBuilder) compareAndRemovePolicies(ruleIDToRemove } func (epBuilder *endpointPolicyBuilder) resetAllNPMAclPolicies() { - for i, acl := range epBuilder.aclPolicies { - if strings.HasPrefix(acl.Id, "azure-acl-") { - // Remove the ACL policy from the list - epBuilder.removeACLPolicyAtIndex(i) - } - } + epBuilder.aclPolicies = []*NPMACLPolSettings{} } -func (epBuilder *endpointPolicyBuilder) removeACLPolicyAtIndex(i int) { - klog.Infof("[DataPlane Windows] Found ACL with ID %s and removing it", epBuilder.aclPolicies[i].Id) - if i == len(epBuilder.aclPolicies)-1 { - epBuilder.aclPolicies = epBuilder.aclPolicies[:i] - return +func (epBuilder *endpointPolicyBuilder) removeACLPolicyAtIndex(indexes map[int]struct{}) { + tempAclPolicies := []*NPMACLPolSettings{} + for i, acl := range epBuilder.aclPolicies { + if _, ok := indexes[i]; !ok { + tempAclPolicies = append(tempAclPolicies, acl) + } } - epBuilder.aclPolicies = append(epBuilder.aclPolicies[:i], epBuilder.aclPolicies[i+1:]...) + epBuilder.aclPolicies = tempAclPolicies } diff --git a/npm/pkg/dataplane/policies/policymanager_windows_test.go b/npm/pkg/dataplane/policies/policymanager_windows_test.go new file mode 100644 index 000000000..8551c663c --- /dev/null +++ b/npm/pkg/dataplane/policies/policymanager_windows_test.go @@ -0,0 +1,200 @@ +package policies + +import ( + "fmt" + "reflect" + "testing" + + "github.com/Azure/azure-container-networking/common" + "github.com/Azure/azure-container-networking/network/hnswrapper" + "github.com/Azure/azure-container-networking/npm/pkg/dataplane/ipsets" + "github.com/Microsoft/hcsshim/hcn" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var ( + expectedACLs = []*hnswrapper.FakeEndpointPolicy{ + { + ID: TestNetworkPolicies[0].ACLs[0].PolicyID, + Protocols: "6", + Direction: "In", + Action: "Block", + LocalAddresses: "azure-npm-3216600258", + RemoteAddresses: "azure-npm-2031808719", + RemotePorts: getPortStr(222, 333), + LocalPorts: "", + Priority: blockRulePriotity, + }, + { + ID: TestNetworkPolicies[0].ACLs[0].PolicyID, + Protocols: "17", + Direction: "In", + Action: "Allow", + LocalAddresses: "azure-npm-3216600258", + RemoteAddresses: "", + LocalPorts: "", + RemotePorts: "", + Priority: allowRulePriotity, + }, + { + ID: TestNetworkPolicies[0].ACLs[0].PolicyID, + Protocols: "17", + Direction: "Out", + Action: "Block", + LocalAddresses: "", + RemoteAddresses: "azure-npm-3216600258", + LocalPorts: "144", + RemotePorts: "", + Priority: blockRulePriotity, + }, + { + ID: TestNetworkPolicies[0].ACLs[0].PolicyID, + Protocols: "256", + Direction: "Out", + Action: "Allow", + LocalAddresses: "", + RemoteAddresses: "azure-npm-3216600258", + LocalPorts: "", + RemotePorts: "", + Priority: allowRulePriotity, + }, + } + + endPointIDList = map[string]string{ + "10.0.0.1": "test1", + "10.0.0.2": "test2", + } +) + +func TestCompareAndRemovePolicies(t *testing.T) { + epbuilder := newEndpointPolicyBuilder() + + testPol := &NPMACLPolSettings{ + Id: "test1", + Protocols: string(TCP), + } + testPol2 := &NPMACLPolSettings{ + Id: "test1", + Protocols: string(UDP), + } + + epbuilder.aclPolicies = append(epbuilder.aclPolicies, []*NPMACLPolSettings{testPol, testPol2}...) + + epbuilder.compareAndRemovePolicies("test1", 2) + + if len(epbuilder.aclPolicies) != 0 { + t.Errorf("Expected 0 policies, got %d", len(epbuilder.aclPolicies)) + } +} + +func TestAddPolicies(t *testing.T) { + pMgr, hns := getPMgr(t) + err := pMgr.AddPolicy(TestNetworkPolicies[0], endPointIDList) + require.NoError(t, err) + + aclID := TestNetworkPolicies[0].ACLs[0].PolicyID + + aclPolicies, err := hns.Cache.ACLPolicies(endPointIDList, aclID) + require.NoError(t, err) + for _, id := range endPointIDList { + acls, ok := aclPolicies[id] + if !ok { + t.Errorf("Expected %s to be in ACLs", id) + } + verifyFakeHNSCacheACLs(t, expectedACLs, acls) + } +} + +func TestRemovePolicies(t *testing.T) { + pMgr, hns := getPMgr(t) + err := pMgr.AddPolicy(TestNetworkPolicies[0], endPointIDList) + require.NoError(t, err) + + aclID := TestNetworkPolicies[0].ACLs[0].PolicyID + + aclPolicies, err := hns.Cache.ACLPolicies(endPointIDList, aclID) + require.NoError(t, err) + for _, id := range endPointIDList { + acls, ok := aclPolicies[id] + if !ok { + t.Errorf("Expected %s to be in ACLs", id) + } + verifyFakeHNSCacheACLs(t, expectedACLs, acls) + } + + err = pMgr.RemovePolicy(TestNetworkPolicies[0].Name, nil) + require.NoError(t, err) + verifyACLCacheIsCleaned(t, hns, len(endPointIDList)) +} + +// Helper functions for UTS + +func getPMgr(t *testing.T) (*PolicyManager, *hnswrapper.Hnsv2wrapperFake) { + hns := ipsets.GetHNSFake(t) + io := common.NewMockIOShimWithFakeHNS(hns) + + for ip, epID := range endPointIDList { + ep := &hcn.HostComputeEndpoint{ + Id: epID, + Name: epID, + IpConfigurations: []hcn.IpConfig{ + { + IpAddress: ip, + }, + }, + } + _, err := hns.CreateEndpoint(ep) + require.NoError(t, err) + } + return NewPolicyManager(io), hns +} + +func verifyFakeHNSCacheACLs(t *testing.T, expected, actual []*hnswrapper.FakeEndpointPolicy) bool { + assert.Equal(t, + len(expected), + len(actual), + fmt.Sprintf("Expected %d ACL, got %d", len(TestNetworkPolicies[0].ACLs), len(actual)), + ) + for i, expectedACL := range expected { + foundACL := false + // While printing actual with %+v it only prints the pointers and it is hard to debug. + // So creating this errStr to print the actual values. + errStr := "" + for j, cacheACL := range actual { + assert.Equal(t, + expectedACL.ID, + actual[i].ID, + fmt.Sprintf("Expected %s, got %s", expectedACL.ID, actual[i].ID), + ) + if reflect.DeepEqual(expectedACL, cacheACL) { + foundACL = true + break + } + errStr += fmt.Sprintf("\n%d: %+v", j, cacheACL) + } + require.True(t, foundACL, fmt.Sprintf("Expected %+v to be in ACLs \n Got: %s ", expectedACL, errStr)) + } + return true +} + +func verifyACLCacheIsCleaned(t *testing.T, hns *hnswrapper.Hnsv2wrapperFake, lenOfEPs int) { + epACLs := hns.Cache.GetAllACLs() + assert.Equal(t, lenOfEPs, len(epACLs)) + for _, acls := range epACLs { + assert.Equal(t, 0, len(acls)) + } +} + +func getPortStr(start, end int32) string { + portStr := fmt.Sprintf("%d", start) + if start == end || end == 0 { + return portStr + } + + for i := start + 1; i <= end; i++ { + portStr += fmt.Sprintf(",%d", i) + } + + return portStr +} diff --git a/npm/pkg/dataplane/policies/testutils.go b/npm/pkg/dataplane/policies/testutils.go index 6fde8dbb7..a64208491 100644 --- a/npm/pkg/dataplane/policies/testutils.go +++ b/npm/pkg/dataplane/policies/testutils.go @@ -56,7 +56,7 @@ var ( Protocol: TCP, }, { - PolicyID: "test2", + PolicyID: "test1", Comment: "comment2", SrcList: []SetInfo{ { @@ -70,7 +70,7 @@ var ( Protocol: UDP, }, { - PolicyID: "test3", + PolicyID: "test1", Comment: "comment3", SrcList: []SetInfo{ { @@ -87,7 +87,7 @@ var ( Protocol: UDP, }, { - PolicyID: "test4", + PolicyID: "test1", Comment: "comment4", SrcList: []SetInfo{ { diff --git a/test/integration/npm/main.go b/test/integration/npm/main.go index 5beca3c99..98021f299 100644 --- a/test/integration/npm/main.go +++ b/test/integration/npm/main.go @@ -39,7 +39,7 @@ var ( Direction: policies.Ingress, }, { - PolicyID: "azure-acl-234", + PolicyID: "azure-acl-123", Target: policies.Allowed, Direction: policies.Ingress, SrcList: []policies.SetInfo{ @@ -62,7 +62,7 @@ var ( func main() { dp, err := dataplane.NewDataPlane(nodeName, common.NewIOShim()) panicOnError(err) - printAndWait() + printAndWait(true) podMetadata := &dataplane.PodMetadata{ PodKey: "a", @@ -80,7 +80,7 @@ func main() { panicOnError(dp.AddToSets([]*ipsets.IPSetMetadata{ipsets.TestNSSet.Metadata}, podMetadataB)) podMetadataC := &dataplane.PodMetadata{ PodKey: "c", - PodIP: "10.240.0.24", + PodIP: "10.240.0.28", NodeName: nodeName, } panicOnError(dp.AddToSets([]*ipsets.IPSetMetadata{ipsets.TestKeyPodSet.Metadata, ipsets.TestNSSet.Metadata}, podMetadataC)) @@ -90,7 +90,7 @@ func main() { panicOnError(dp.ApplyDataPlane()) - printAndWait() + printAndWait(true) panicOnError(dp.AddToLists([]*ipsets.IPSetMetadata{ipsets.TestKeyNSList.Metadata, ipsets.TestKVNSList.Metadata}, []*ipsets.IPSetMetadata{ipsets.TestNSSet.Metadata})) @@ -107,23 +107,66 @@ func main() { dp.DeleteIPSet(ipsets.TestKVPodSet.Metadata) panicOnError(dp.ApplyDataPlane()) - printAndWait() + printAndWait(true) panicOnError(dp.RemoveFromSets([]*ipsets.IPSetMetadata{ipsets.TestNSSet.Metadata}, podMetadata)) dp.DeleteIPSet(ipsets.TestNSSet.Metadata) panicOnError(dp.ApplyDataPlane()) - printAndWait() + printAndWait(true) panicOnError(dp.AddPolicy(testNetPol)) + printAndWait(true) + + panicOnError(dp.RemovePolicy(testNetPol.Name)) + printAndWait(true) + + panicOnError(dp.AddPolicy(testNetPol)) + printAndWait(true) + + podMetadataD = &dataplane.PodMetadata{ + PodKey: "d", + PodIP: "10.240.0.27", + NodeName: nodeName, + } + panicOnError(dp.AddToSets([]*ipsets.IPSetMetadata{ipsets.TestKeyPodSet.Metadata, ipsets.TestNSSet.Metadata}, podMetadataD)) + panicOnError(dp.ApplyDataPlane()) + printAndWait(true) + + panicOnError(dp.RemovePolicy(testNetPol.Name)) panicOnError(dp.AddPolicy(policies.TestNetworkPolicies[0])) panicOnError(dp.AddPolicy(policies.TestNetworkPolicies[1])) - printAndWait() + printAndWait(true) panicOnError(dp.RemovePolicy(policies.TestNetworkPolicies[2].Name)) // no-op panicOnError(dp.AddPolicy(policies.TestNetworkPolicies[2])) - printAndWait() + printAndWait(true) panicOnError(dp.RemovePolicy(policies.TestNetworkPolicies[1].Name)) + + testPolicyManager() +} + +func testPolicyManager() { + pMgr := policies.NewPolicyManager(common.NewIOShim()) + + panicOnError(pMgr.Reset()) + printAndWait(false) + + panicOnError(pMgr.Initialize()) + printAndWait(false) + + panicOnError(pMgr.AddPolicy(policies.TestNetworkPolicies[0], nil)) + printAndWait(false) + + panicOnError(pMgr.AddPolicy(policies.TestNetworkPolicies[1], nil)) + printAndWait(false) + + // remove something that doesn't exist + panicOnError(pMgr.RemovePolicy(policies.TestNetworkPolicies[2].Name, nil)) + printAndWait(false) + + panicOnError(pMgr.AddPolicy(policies.TestNetworkPolicies[2], nil)) + printAndWait(false) } func panicOnError(err error) { @@ -132,10 +175,12 @@ func panicOnError(err error) { } } -func printAndWait() { +func printAndWait(wait bool) { fmt.Printf("#####################\nCompleted running, please check relevant commands, script will resume in %d secs\n#############\n", MaxSleepTime) - for i := 0; i < MaxSleepTime; i++ { - fmt.Print(".") - time.Sleep(time.Second) + if wait { + for i := 0; i < MaxSleepTime; i++ { + fmt.Print(".") + time.Sleep(time.Second) + } } }