diff --git a/cni/network/network.go b/cni/network/network.go index 4b484a541..5a42d67ff 100644 --- a/cni/network/network.go +++ b/cni/network/network.go @@ -21,6 +21,7 @@ import ( "github.com/Azure/azure-container-networking/cns" cnscli "github.com/Azure/azure-container-networking/cns/client" "github.com/Azure/azure-container-networking/common" + "github.com/Azure/azure-container-networking/dhcp" "github.com/Azure/azure-container-networking/iptables" "github.com/Azure/azure-container-networking/netio" "github.com/Azure/azure-container-networking/netlink" @@ -130,7 +131,7 @@ func NewPlugin(name string, nl := netlink.NewNetlink() // Setup network manager. - nm, err := network.NewNetworkManager(nl, platform.NewExecClient(logger), &netio.NetIO{}, network.NewNamespaceClient(), iptables.NewClient()) + nm, err := network.NewNetworkManager(nl, platform.NewExecClient(logger), &netio.NetIO{}, network.NewNamespaceClient(), iptables.NewClient(), dhcp.New(logger)) if err != nil { return nil, err } diff --git a/dhcp/dhcp_linux.go b/dhcp/dhcp_linux.go new file mode 100644 index 000000000..9e7a029c0 --- /dev/null +++ b/dhcp/dhcp_linux.go @@ -0,0 +1,461 @@ +//go:build linux +// +build linux + +package dhcp + +import ( + "bytes" + "context" + "crypto/rand" + "encoding/binary" + "io" + "net" + "time" + + "github.com/pkg/errors" + "go.uber.org/zap" + "golang.org/x/net/ipv4" + "golang.org/x/sys/unix" +) + +const ( + dhcpDiscover = 1 + bootRequest = 1 + ethPAll = 0x0003 + MaxUDPReceivedPacketSize = 8192 + dhcpServerPort = 67 + dhcpClientPort = 68 + dhcpOpCodeReply = 2 + bootpMinLen = 300 + bytesInAddress = 4 // bytes in an ip address + macBytes = 6 // bytes in a mac address + udpProtocol = 17 + + opRequest = 1 + htypeEthernet = 1 + hlenEthernet = 6 + hops = 0 + secs = 0 + flags = 0x8000 // Broadcast flag +) + +// TransactionID represents a 4-byte DHCP transaction ID as defined in RFC 951, +// Section 3. +// +// The TransactionID is used to match DHCP replies to their original request. +type TransactionID [4]byte + +var ( + magicCookie = []byte{0x63, 0x82, 0x53, 0x63} // DHCP magic cookie + DefaultReadTimeout = 3 * time.Second + DefaultTimeout = 3 * time.Second +) + +type DHCP struct { + logger *zap.Logger +} + +func New(logger *zap.Logger) *DHCP { + return &DHCP{ + logger: logger, + } +} + +type Socket struct { + fd int + remoteAddr unix.SockaddrInet4 +} + +// Linux specific +// returns a writer which should always be closed, even if we return an error +func NewWriteSocket(ifname string, remoteAddr unix.SockaddrInet4) (io.WriteCloser, error) { + fd, err := MakeBroadcastSocket(ifname) + ret := &Socket{ + fd: fd, + remoteAddr: remoteAddr, + } + if err != nil { + return ret, errors.Wrap(err, "could not make dhcp write socket") + } + + return ret, nil +} + +func (s *Socket) Write(packetBytes []byte) (int, error) { + err := unix.Sendto(s.fd, packetBytes, 0, &s.remoteAddr) + if err != nil { + return 0, errors.Wrap(err, "failed unix send to") + } + return len(packetBytes), nil +} + +// returns a reader which should always be closed, even if we return an error +func NewReadSocket(ifname string, timeout time.Duration) (io.ReadCloser, error) { + fd, err := makeListeningSocket(ifname, timeout) + ret := &Socket{ + fd: fd, + } + if err != nil { + return ret, errors.Wrap(err, "could not make dhcp read socket") + } + + return ret, nil +} + +func (s *Socket) Read(p []byte) (n int, err error) { + n, _, innerErr := unix.Recvfrom(s.fd, p, 0) + if innerErr != nil { + return 0, errors.Wrap(err, "failed unix recv from") + } + return n, nil +} + +func (s *Socket) Close() error { + // do not attempt to close fd with -1 as they are not valid + if s.fd == -1 { + return nil + } + // Ensure the file descriptor is closed when done + if err := unix.Close(s.fd); err != nil { + return errors.Wrap(err, "error closing dhcp unix socket") + } + return nil +} + +// GenerateTransactionID generates a random 32-bits number suitable for use as TransactionID +func GenerateTransactionID() (TransactionID, error) { + var xid TransactionID + _, err := rand.Read(xid[:]) + if err != nil { + return xid, errors.Errorf("could not get random number: %v", err) + } + return xid, nil +} + +func makeListeningSocket(ifname string, timeout time.Duration) (int, error) { + // reference: https://manned.org/packet.7 + // starts listening to the specified protocol, or none if zero + // the SockaddrLinkLayer also ensures packets for the htons(unix.ETH_P_IP) prot are received + fd, err := unix.Socket(unix.AF_PACKET, unix.SOCK_DGRAM, int(htons(unix.ETH_P_IP))) + if err != nil { + return fd, errors.Wrap(err, "dhcp socket creation failure") + } + iface, err := net.InterfaceByName(ifname) + if err != nil { + return fd, errors.Wrap(err, "dhcp failed to get interface") + } + llAddr := unix.SockaddrLinklayer{ + Ifindex: iface.Index, + Protocol: htons(unix.ETH_P_IP), + } + err = unix.Bind(fd, &llAddr) + + // set max time waiting without any data received + timeval := unix.NsecToTimeval(timeout.Nanoseconds()) + if innerErr := unix.SetsockoptTimeval(fd, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &timeval); innerErr != nil { + return fd, errors.Wrap(innerErr, "could not set timeout on socket") + } + + return fd, errors.Wrap(err, "dhcp failed to bind") +} + +// MakeBroadcastSocket creates a socket that can be passed to unix.Sendto +// that will send packets out to the broadcast address. +func MakeBroadcastSocket(ifname string) (int, error) { + fd, err := makeRawSocket(ifname) + if err != nil { + return fd, err + } + // enables broadcast (disabled by default) + err = unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_BROADCAST, 1) + if err != nil { + return fd, errors.Wrap(err, "dhcp failed to set sockopt") + } + return fd, nil +} + +// conversion between host and network byte order +func htons(v uint16) uint16 { + var tmp [2]byte + binary.BigEndian.PutUint16(tmp[:], v) + return binary.LittleEndian.Uint16(tmp[:]) +} + +func BindToInterface(fd int, ifname string) error { + return errors.Wrap(unix.BindToDevice(fd, ifname), "failed to bind to device") +} + +// makeRawSocket creates a socket that can be passed to unix.Sendto. +func makeRawSocket(ifname string) (int, error) { + // AF_INET sends via IPv4, SOCK_RAW means create an ip datagram socket (skips udp transport layer, see below) + fd, err := unix.Socket(unix.AF_INET, unix.SOCK_RAW, unix.IPPROTO_RAW) + if err != nil { + return fd, errors.Wrap(err, "dhcp raw socket creation failure") + } + // Later on when we write to this socket, our packet already contains the header (we create it with MakeRawUDPPacket). + err = unix.SetsockoptInt(fd, unix.IPPROTO_IP, unix.IP_HDRINCL, 1) + if err != nil { + return fd, errors.Wrap(err, "dhcp failed to set hdrincl raw sockopt") + } + err = BindToInterface(fd, ifname) + if err != nil { + return fd, errors.Wrap(err, "dhcp failed to bind to interface") + } + return fd, nil +} + +// Build DHCP Discover Packet +func buildDHCPDiscover(mac net.HardwareAddr, txid TransactionID) ([]byte, error) { + if len(mac) != macBytes { + return nil, errors.Errorf("invalid MAC address length") + } + + var packet bytes.Buffer + + // BOOTP header + packet.WriteByte(opRequest) // op: BOOTREQUEST (1) + packet.WriteByte(htypeEthernet) // htype: Ethernet (1) + packet.WriteByte(hlenEthernet) // hlen: MAC address length (6) + packet.WriteByte(hops) // hops: 0 + packet.Write(txid[:]) // xid: Transaction ID (4 bytes) + err := binary.Write(&packet, binary.BigEndian, uint16(secs)) // secs: Seconds elapsed + if err != nil { + return nil, errors.Wrap(err, "failed to write seconds elapsed") + } + err = binary.Write(&packet, binary.BigEndian, uint16(flags)) // flags: Broadcast flag + if err != nil { + return nil, errors.Wrap(err, "failed to write broadcast flag") + } + + // Client IP address (0.0.0.0) + packet.Write(make([]byte, bytesInAddress)) + // Your IP address (0.0.0.0) + packet.Write(make([]byte, bytesInAddress)) + // Server IP address (0.0.0.0) + packet.Write(make([]byte, bytesInAddress)) + // Gateway IP address (0.0.0.0) + packet.Write(make([]byte, bytesInAddress)) + + // chaddr: Client hardware address (MAC address) + paddingBytes := 10 + packet.Write(mac) // MAC address (6 bytes) + packet.Write(make([]byte, paddingBytes)) // Padding to 16 bytes + + // sname: Server host name (64 bytes) + serverHostNameBytes := 64 + packet.Write(make([]byte, serverHostNameBytes)) + // file: Boot file name (128 bytes) + bootFileNameBytes := 128 + packet.Write(make([]byte, bootFileNameBytes)) + + // Magic cookie (DHCP) + err = binary.Write(&packet, binary.BigEndian, magicCookie) + if err != nil { + return nil, errors.Wrap(err, "failed to write magic cookie") + } + + // DHCP options (minimal required options for DISCOVER) + packet.Write([]byte{ + 53, 1, 1, // Option 53: DHCP Message Type (1 = DHCP Discover) + 55, 3, 1, 3, 6, // Option 55: Parameter Request List (1 = Subnet Mask, 3 = Router, 6 = DNS) + 255, // End option + }) + + // padding length to 300 bytes + var value uint8 // default is zero + if packet.Len() < bootpMinLen { + packet.Write(bytes.Repeat([]byte{value}, bootpMinLen-packet.Len())) + } + + return packet.Bytes(), nil +} + +// MakeRawUDPPacket converts a payload (a serialized packet) into a +// raw UDP packet for the specified serverAddr from the specified clientAddr. +func MakeRawUDPPacket(payload []byte, serverAddr, clientAddr net.UDPAddr) ([]byte, error) { + udpBytes := 8 + udp := make([]byte, udpBytes) + binary.BigEndian.PutUint16(udp[:2], uint16(clientAddr.Port)) + binary.BigEndian.PutUint16(udp[2:4], uint16(serverAddr.Port)) + totalLen := uint16(udpBytes + len(payload)) + binary.BigEndian.PutUint16(udp[4:6], totalLen) + binary.BigEndian.PutUint16(udp[6:8], 0) // try to offload the checksum + + headerVersion := 4 + headerLen := 20 + headerTTL := 64 + + h := ipv4.Header{ + Version: headerVersion, // nolint + Len: headerLen, // nolint + TotalLen: headerLen + len(udp) + len(payload), + TTL: headerTTL, + Protocol: udpProtocol, // UDP + Dst: serverAddr.IP, + Src: clientAddr.IP, + } + ret, err := h.Marshal() + if err != nil { + return nil, errors.Wrap(err, "failed to marshal when making udp packet") + } + ret = append(ret, udp...) + ret = append(ret, payload...) + return ret, nil +} + +// Receive DHCP response packet using reader +func (c *DHCP) receiveDHCPResponse(ctx context.Context, reader io.ReadCloser, xid TransactionID) error { + recvErrors := make(chan error, 1) + // Recvfrom is a blocking call, so if something goes wrong with its timeout it won't return. + + // Additionally, the timeout on the socket (on the Read(...)) call is how long until the socket times out and gives an error, + // but it won't error if we do get some sort of data within the time out period. + + // If we get some data (even if it is not the packet we are looking for, like wrong txid, wrong response opcode etc.) + // then we continue in the for loop. We then call recvfrom again which will reset the timeout period + // Without the secondary timeout at the bottom of the function, we could stay stuck in the for loop as long as we receive packets. + go func(errs chan<- error) { + // loop will only exit if there is an error, context canceled, or we find our reply packet + for { + if ctx.Err() != nil { + errs <- ctx.Err() + return + } + + buf := make([]byte, MaxUDPReceivedPacketSize) + // Blocks until data received or timeout period is reached + n, innerErr := reader.Read(buf) + if innerErr != nil { + errs <- innerErr + return + } + // check header + var iph ipv4.Header + if err := iph.Parse(buf[:n]); err != nil { + // skip non-IP data + continue + } + if iph.Protocol != udpProtocol { + // skip non-UDP packets + continue + } + udph := buf[iph.Len:n] + // source is from dhcp server if receiving + srcPort := int(binary.BigEndian.Uint16(udph[0:2])) + if srcPort != dhcpServerPort { + continue + } + // client is to dhcp client if receiving + dstPort := int(binary.BigEndian.Uint16(udph[2:4])) + if dstPort != dhcpClientPort { + continue + } + // check payload + pLen := int(binary.BigEndian.Uint16(udph[4:6])) + payload := buf[iph.Len+8 : iph.Len+pLen] + + // retrieve opcode from payload + opcode := payload[0] // opcode is first byte + // retrieve txid from payload + txidOffset := 4 // after 4 bytes, the txid starts + // the txid is 4 bytes, so we take four bytes after the offset + txid := payload[txidOffset : txidOffset+4] + + c.logger.Info("Received packet", zap.Int("opCode", int(opcode)), zap.Any("transactionID", TransactionID(txid))) + if opcode != dhcpOpCodeReply { + continue // opcode is not a reply, so continue + } + + if TransactionID(txid) == xid { + break + } + } + // only occurs if we find our reply packet successfully + // a nil error means a reply was found for this txid + recvErrors <- nil + }(recvErrors) + + // sends a message on repeat after timeout, but only the first one matters + ticker := time.NewTicker(DefaultReadTimeout) + defer ticker.Stop() + + select { + case err := <-recvErrors: + if err != nil { + return errors.Wrap(err, "error during receiving") + } + case <-ticker.C: + return errors.New("timed out waiting for replies") + } + return nil +} + +// Issues a DHCP Discover packet from the nic specified by mac and name ifname +// Returns nil if a reply to the transaction was received, or error if time out +// Does not return the DHCP Offer that was received from the DHCP server +func (c *DHCP) DiscoverRequest(ctx context.Context, mac net.HardwareAddr, ifname string) error { + txid, err := GenerateTransactionID() + if err != nil { + return errors.Wrap(err, "failed to generate random transaction id") + } + + // Used in later steps + raddr := &net.UDPAddr{IP: net.IPv4bcast, Port: dhcpServerPort} + laddr := &net.UDPAddr{IP: net.IPv4zero, Port: dhcpClientPort} + var destination [net.IPv4len]byte + copy(destination[:], raddr.IP.To4()) + + // Build a DHCP discover packet + dhcpPacket, err := buildDHCPDiscover(mac, txid) + if err != nil { + return errors.Wrap(err, "failed to build dhcp discover packet") + } + // Make UDP packet from dhcp packet in previous steps + packetToSendBytes, err := MakeRawUDPPacket(dhcpPacket, *raddr, *laddr) + if err != nil { + return errors.Wrap(err, "error making raw udp packet") + } + + // Make writer + remoteAddr := unix.SockaddrInet4{Port: laddr.Port, Addr: destination} + writer, err := NewWriteSocket(ifname, remoteAddr) + defer func() { + // Ensure the file descriptor is closed when done + if err = writer.Close(); err != nil { + c.logger.Error("Error closing dhcp writer socket:", zap.Error(err)) + } + }() + if err != nil { + return errors.Wrap(err, "failed to make broadcast socket") + } + + // Make reader + deadline, ok := ctx.Deadline() + if !ok { + return errors.New("no deadline for passed in context") + } + timeout := time.Until(deadline) + // note: if the write/send takes a long time DiscoverRequest might take a bit longer than the deadline + reader, err := NewReadSocket(ifname, timeout) + defer func() { + // Ensure the file descriptor is closed when done + if err = reader.Close(); err != nil { + c.logger.Error("Error closing dhcp reader socket:", zap.Error(err)) + } + }() + if err != nil { + return errors.Wrap(err, "failed to make listening socket") + } + + // Once writer and reader created, start sending and receiving + _, err = writer.Write(packetToSendBytes) + if err != nil { + return errors.Wrap(err, "failed to send dhcp discover packet") + } + + c.logger.Info("DHCP Discover packet was sent successfully", zap.Any("transactionID", txid)) + + // Wait for DHCP response (Offer) + res := c.receiveDHCPResponse(ctx, reader, txid) + return res +} diff --git a/dhcp/dhcp_windows.go b/dhcp/dhcp_windows.go new file mode 100644 index 000000000..7b23dbeef --- /dev/null +++ b/dhcp/dhcp_windows.go @@ -0,0 +1,22 @@ +package dhcp + +import ( + "context" + "net" + + "go.uber.org/zap" +) + +type DHCP struct { + logger *zap.Logger +} + +func New(logger *zap.Logger) *DHCP { + return &DHCP{ + logger: logger, + } +} + +func (c *DHCP) DiscoverRequest(_ context.Context, _ net.HardwareAddr, _ string) error { + return nil +} diff --git a/network/dhcp.go b/network/dhcp.go new file mode 100644 index 000000000..82be4bd97 --- /dev/null +++ b/network/dhcp.go @@ -0,0 +1,16 @@ +package network + +import ( + "context" + "net" +) + +type dhcpClient interface { + DiscoverRequest(context.Context, net.HardwareAddr, string) error +} + +type mockDHCP struct{} + +func (netns *mockDHCP) DiscoverRequest(context.Context, net.HardwareAddr, string) error { + return nil +} diff --git a/network/endpoint.go b/network/endpoint.go index c6daf4701..e5b55a0b3 100644 --- a/network/endpoint.go +++ b/network/endpoint.go @@ -169,6 +169,7 @@ func (nw *network) newEndpoint( netioCli netio.NetIOInterface, nsc NamespaceClientInterface, iptc ipTablesClient, + dhcpc dhcpClient, epInfo *EndpointInfo, ) (*endpoint, error) { var ep *endpoint @@ -182,7 +183,7 @@ func (nw *network) newEndpoint( // Call the platform implementation. // Pass nil for epClient and will be initialized in newendpointImpl - ep, err = nw.newEndpointImpl(apipaCli, nl, plc, netioCli, nil, nsc, iptc, epInfo) + ep, err = nw.newEndpointImpl(apipaCli, nl, plc, netioCli, nil, nsc, iptc, dhcpc, epInfo) if err != nil { return nil, err } @@ -195,7 +196,7 @@ func (nw *network) newEndpoint( // DeleteEndpoint deletes an existing endpoint from the network. func (nw *network) deleteEndpoint(nl netlink.NetlinkInterface, plc platform.ExecClient, nioc netio.NetIOInterface, nsc NamespaceClientInterface, - iptc ipTablesClient, endpointID string, + iptc ipTablesClient, dhcpc dhcpClient, endpointID string, ) error { var err error @@ -215,7 +216,7 @@ func (nw *network) deleteEndpoint(nl netlink.NetlinkInterface, plc platform.Exec // Call the platform implementation. // Pass nil for epClient and will be initialized in deleteEndpointImpl - err = nw.deleteEndpointImpl(nl, plc, nil, nioc, nsc, iptc, ep) + err = nw.deleteEndpointImpl(nl, plc, nil, nioc, nsc, iptc, dhcpc, ep) if err != nil { return err } diff --git a/network/endpoint_linux.go b/network/endpoint_linux.go index faca6c4c9..3f2d8aa77 100644 --- a/network/endpoint_linux.go +++ b/network/endpoint_linux.go @@ -57,6 +57,7 @@ func (nw *network) newEndpointImpl( testEpClient EndpointClient, nsc NamespaceClientInterface, iptc ipTablesClient, + dhcpclient dhcpClient, epInfo *EndpointInfo, ) (*endpoint, error) { var ( @@ -167,7 +168,7 @@ func (nw *network) newEndpointImpl( epClient = NewLinuxBridgeEndpointClient(nw.extIf, hostIfName, contIfName, nw.Mode, nl, plc) } else if epInfo.NICType == cns.NodeNetworkInterfaceFrontendNIC { logger.Info("Secondary client") - epClient = NewSecondaryEndpointClient(nl, netioCli, plc, nsc, ep) + epClient = NewSecondaryEndpointClient(nl, netioCli, plc, nsc, dhcpclient, ep) } else { logger.Info("Transparent client") epClient = NewTransparentEndpointClient(nw.extIf, hostIfName, contIfName, nw.Mode, nl, netioCli, plc) @@ -265,7 +266,7 @@ func (nw *network) newEndpointImpl( // deleteEndpointImpl deletes an existing endpoint from the network. func (nw *network) deleteEndpointImpl(nl netlink.NetlinkInterface, plc platform.ExecClient, epClient EndpointClient, nioc netio.NetIOInterface, nsc NamespaceClientInterface, - iptc ipTablesClient, ep *endpoint, + iptc ipTablesClient, dhcpc dhcpClient, ep *endpoint, ) error { // Delete the veth pair by deleting one of the peer interfaces. // Deleting the host interface is more convenient since it does not require @@ -287,7 +288,7 @@ func (nw *network) deleteEndpointImpl(nl netlink.NetlinkInterface, plc platform. } else { // delete if secondary interfaces populated or endpoint of type delegated (new way) if len(ep.SecondaryInterfaces) > 0 || ep.NICType == cns.NodeNetworkInterfaceFrontendNIC { - epClient = NewSecondaryEndpointClient(nl, nioc, plc, nsc, ep) + epClient = NewSecondaryEndpointClient(nl, nioc, plc, nsc, dhcpc, ep) epClient.DeleteEndpointRules(ep) //nolint:errcheck // ignore error epClient.DeleteEndpoints(ep) diff --git a/network/endpoint_test.go b/network/endpoint_test.go index 3835af0f2..e9edcef4e 100644 --- a/network/endpoint_test.go +++ b/network/endpoint_test.go @@ -186,7 +186,7 @@ var _ = Describe("Test Endpoint", func() { It("Should be added", func() { // Add endpoint with valid id ep, err := nw.newEndpointImpl(nil, netlink.NewMockNetlink(false, ""), platform.NewMockExecClient(false), - netio.NewMockNetIO(false, 0), NewMockEndpointClient(nil), NewMockNamespaceClient(), iptables.NewClient(), epInfo) + netio.NewMockNetIO(false, 0), NewMockEndpointClient(nil), NewMockNamespaceClient(), iptables.NewClient(), &mockDHCP{}, epInfo) Expect(err).NotTo(HaveOccurred()) Expect(ep).NotTo(BeNil()) Expect(ep.Id).To(Equal(epInfo.EndpointID)) @@ -198,7 +198,7 @@ var _ = Describe("Test Endpoint", func() { extIf: &externalInterface{IPv4Gateway: net.ParseIP("192.168.0.1")}, } ep, err := nw2.newEndpointImpl(nil, netlink.NewMockNetlink(false, ""), platform.NewMockExecClient(false), - netio.NewMockNetIO(false, 0), NewMockEndpointClient(nil), NewMockNamespaceClient(), iptables.NewClient(), epInfo) + netio.NewMockNetIO(false, 0), NewMockEndpointClient(nil), NewMockNamespaceClient(), iptables.NewClient(), &mockDHCP{}, epInfo) Expect(err).NotTo(HaveOccurred()) Expect(ep).NotTo(BeNil()) Expect(ep.Id).To(Equal(epInfo.EndpointID)) @@ -216,7 +216,7 @@ var _ = Describe("Test Endpoint", func() { Expect(err).ToNot(HaveOccurred()) // Adding endpoint with same id should fail and delete should cleanup the state ep2, err := nw.newEndpointImpl(nil, netlink.NewMockNetlink(false, ""), platform.NewMockExecClient(false), - netio.NewMockNetIO(false, 0), mockCli, NewMockNamespaceClient(), iptables.NewClient(), epInfo) + netio.NewMockNetIO(false, 0), mockCli, NewMockNamespaceClient(), iptables.NewClient(), &mockDHCP{}, epInfo) Expect(err).To(HaveOccurred()) Expect(ep2).To(BeNil()) assert.Contains(GinkgoT(), err.Error(), "Endpoint already exists") @@ -226,17 +226,17 @@ var _ = Describe("Test Endpoint", func() { // Adding an endpoint with an id. mockCli := NewMockEndpointClient(nil) ep2, err := nw.newEndpointImpl(nil, netlink.NewMockNetlink(false, ""), platform.NewMockExecClient(false), - netio.NewMockNetIO(false, 0), mockCli, NewMockNamespaceClient(), iptables.NewClient(), epInfo) + netio.NewMockNetIO(false, 0), mockCli, NewMockNamespaceClient(), iptables.NewClient(), &mockDHCP{}, epInfo) Expect(err).ToNot(HaveOccurred()) Expect(ep2).ToNot(BeNil()) Expect(len(mockCli.endpoints)).To(Equal(1)) // Deleting the endpoint //nolint:errcheck // ignore error - nw.deleteEndpointImpl(netlink.NewMockNetlink(false, ""), platform.NewMockExecClient(false), mockCli, netio.NewMockNetIO(false, 0), NewMockNamespaceClient(), iptables.NewClient(), ep2) + nw.deleteEndpointImpl(netlink.NewMockNetlink(false, ""), platform.NewMockExecClient(false), mockCli, netio.NewMockNetIO(false, 0), NewMockNamespaceClient(), iptables.NewClient(), &mockDHCP{}, ep2) Expect(len(mockCli.endpoints)).To(Equal(0)) // Deleting same endpoint with same id should not fail //nolint:errcheck // ignore error - nw.deleteEndpointImpl(netlink.NewMockNetlink(false, ""), platform.NewMockExecClient(false), mockCli, netio.NewMockNetIO(false, 0), NewMockNamespaceClient(), iptables.NewClient(), ep2) + nw.deleteEndpointImpl(netlink.NewMockNetlink(false, ""), platform.NewMockExecClient(false), mockCli, netio.NewMockNetIO(false, 0), NewMockNamespaceClient(), iptables.NewClient(), &mockDHCP{}, ep2) Expect(len(mockCli.endpoints)).To(Equal(0)) }) }) @@ -256,7 +256,7 @@ var _ = Describe("Test Endpoint", func() { extIf: &externalInterface{IPv4Gateway: net.ParseIP("192.168.0.1")}, } ep, err := nw2.newEndpointImpl(nil, netlink.NewMockNetlink(false, ""), platform.NewMockExecClient(false), - netio.NewMockNetIO(false, 0), NewMockEndpointClient(nil), NewMockNamespaceClient(), iptables.NewClient(), epInfo) + netio.NewMockNetIO(false, 0), NewMockEndpointClient(nil), NewMockNamespaceClient(), iptables.NewClient(), &mockDHCP{}, epInfo) Expect(err).NotTo(HaveOccurred()) Expect(ep).NotTo(BeNil()) Expect(ep.Id).To(Equal(epInfo.EndpointID)) @@ -282,7 +282,7 @@ var _ = Describe("Test Endpoint", func() { extIf: &externalInterface{IPv4Gateway: net.ParseIP("192.168.0.1")}, } ep, err := nw2.newEndpointImpl(nil, netlink.NewMockNetlink(false, ""), platform.NewMockExecClient(false), - netio.NewMockNetIO(false, 0), NewMockEndpointClient(nil), NewMockNamespaceClient(), iptables.NewClient(), epInfo) + netio.NewMockNetIO(false, 0), NewMockEndpointClient(nil), NewMockNamespaceClient(), iptables.NewClient(), &mockDHCP{}, epInfo) Expect(err).NotTo(HaveOccurred()) Expect(ep).NotTo(BeNil()) Expect(ep.Id).To(Equal(epInfo.EndpointID)) @@ -309,11 +309,11 @@ var _ = Describe("Test Endpoint", func() { } return nil - }), NewMockNamespaceClient(), iptables.NewClient(), epInfo) + }), NewMockNamespaceClient(), iptables.NewClient(), &mockDHCP{}, epInfo) Expect(err).To(HaveOccurred()) Expect(ep).To(BeNil()) ep, err = nw.newEndpointImpl(nil, netlink.NewMockNetlink(false, ""), platform.NewMockExecClient(false), - netio.NewMockNetIO(false, 0), NewMockEndpointClient(nil), NewMockNamespaceClient(), iptables.NewClient(), epInfo) + netio.NewMockNetIO(false, 0), NewMockEndpointClient(nil), NewMockNamespaceClient(), iptables.NewClient(), &mockDHCP{}, epInfo) Expect(err).NotTo(HaveOccurred()) Expect(ep).NotTo(BeNil()) Expect(ep.Id).To(Equal(epInfo.EndpointID)) @@ -342,14 +342,14 @@ var _ = Describe("Test Endpoint", func() { It("Should not add endpoint to the network when there is an error", func() { secondaryEpInfo.MacAddress = netio.BadHwAddr // mock netlink will fail to set link state on bad eth ep, err := nw.newEndpointImpl(nil, netlink.NewMockNetlink(false, ""), platform.NewMockExecClient(false), - netio.NewMockNetIO(false, 0), nil, NewMockNamespaceClient(), iptables.NewClient(), secondaryEpInfo) + netio.NewMockNetIO(false, 0), nil, NewMockNamespaceClient(), iptables.NewClient(), &mockDHCP{}, secondaryEpInfo) Expect(err).To(HaveOccurred()) Expect(err.Error()).To(Equal("SecondaryEndpointClient Error: " + netlink.ErrorMockNetlink.Error())) Expect(ep).To(BeNil()) // should not panic or error when going through the unified endpoint impl flow with only the delegated nic type fields secondaryEpInfo.MacAddress = netio.HwAddr ep, err = nw.newEndpointImpl(nil, netlink.NewMockNetlink(false, ""), platform.NewMockExecClient(false), - netio.NewMockNetIO(false, 0), nil, NewMockNamespaceClient(), iptables.NewClient(), secondaryEpInfo) + netio.NewMockNetIO(false, 0), nil, NewMockNamespaceClient(), iptables.NewClient(), &mockDHCP{}, secondaryEpInfo) Expect(err).ToNot(HaveOccurred()) Expect(ep.Id).To(Equal(epInfo.EndpointID)) }) @@ -357,12 +357,12 @@ var _ = Describe("Test Endpoint", func() { It("Should add endpoint when there are no errors", func() { secondaryEpInfo.MacAddress = netio.HwAddr ep, err := nw.newEndpointImpl(nil, netlink.NewMockNetlink(false, ""), platform.NewMockExecClient(false), - netio.NewMockNetIO(false, 0), nil, NewMockNamespaceClient(), iptables.NewClient(), secondaryEpInfo) + netio.NewMockNetIO(false, 0), nil, NewMockNamespaceClient(), iptables.NewClient(), &mockDHCP{}, secondaryEpInfo) Expect(err).ToNot(HaveOccurred()) Expect(ep.Id).To(Equal(epInfo.EndpointID)) ep, err = nw.newEndpointImpl(nil, netlink.NewMockNetlink(false, ""), platform.NewMockExecClient(false), - netio.NewMockNetIO(false, 0), nil, NewMockNamespaceClient(), iptables.NewClient(), epInfo) + netio.NewMockNetIO(false, 0), nil, NewMockNamespaceClient(), iptables.NewClient(), &mockDHCP{}, epInfo) Expect(err).ToNot(HaveOccurred()) Expect(ep.Id).To(Equal(epInfo.EndpointID)) }) diff --git a/network/endpoint_windows.go b/network/endpoint_windows.go index 3e927b0b6..44bac6aea 100644 --- a/network/endpoint_windows.go +++ b/network/endpoint_windows.go @@ -150,6 +150,7 @@ func (nw *network) newEndpointImpl( _ EndpointClient, _ NamespaceClientInterface, _ ipTablesClient, + _ dhcpClient, epInfo *EndpointInfo, ) (*endpoint, error) { if epInfo.NICType == cns.BackendNIC { @@ -521,7 +522,7 @@ func (nw *network) newEndpointImplHnsV2(cli apipaClient, epInfo *EndpointInfo) ( // deleteEndpointImpl deletes an existing endpoint from the network. func (nw *network) deleteEndpointImpl(_ netlink.NetlinkInterface, _ platform.ExecClient, _ EndpointClient, _ netio.NetIOInterface, _ NamespaceClientInterface, - _ ipTablesClient, ep *endpoint, + _ ipTablesClient, _ dhcpClient, ep *endpoint, ) error { // endpoint deletion is not required for IB if ep.NICType == cns.BackendNIC { diff --git a/network/endpoint_windows_test.go b/network/endpoint_windows_test.go index 4b6588cd3..c9ebb2634 100644 --- a/network/endpoint_windows_test.go +++ b/network/endpoint_windows_test.go @@ -107,7 +107,8 @@ func TestDeleteEndpointImplHnsV2ForIB(t *testing.T) { } mockCli := NewMockEndpointClient(nil) - err := nw.deleteEndpointImpl(netlink.NewMockNetlink(false, ""), platform.NewMockExecClient(false), mockCli, netio.NewMockNetIO(false, 0), NewMockNamespaceClient(), iptables.NewClient(), &ep) + err := nw.deleteEndpointImpl(netlink.NewMockNetlink(false, ""), platform.NewMockExecClient(false), mockCli, + netio.NewMockNetIO(false, 0), NewMockNamespaceClient(), iptables.NewClient(), &mockDHCP{}, &ep) if err != nil { t.Fatal("endpoint deletion for IB is executed") } @@ -134,7 +135,8 @@ func TestDeleteEndpointImplHnsV2WithEmptyHNSID(t *testing.T) { // should return nil because HnsID is empty mockCli := NewMockEndpointClient(nil) - err := nw.deleteEndpointImpl(netlink.NewMockNetlink(false, ""), platform.NewMockExecClient(false), mockCli, netio.NewMockNetIO(false, 0), NewMockNamespaceClient(), iptables.NewClient(), &ep) + err := nw.deleteEndpointImpl(netlink.NewMockNetlink(false, ""), platform.NewMockExecClient(false), mockCli, + netio.NewMockNetIO(false, 0), NewMockNamespaceClient(), iptables.NewClient(), &mockDHCP{}, &ep) if err != nil { t.Fatal("endpoint deletion gets executed") } @@ -492,7 +494,7 @@ func TestNewEndpointImplHnsv2ForIBHappyPath(t *testing.T) { // Happy Path endpoint, err := nw.newEndpointImpl(nil, netlink.NewMockNetlink(false, ""), platform.NewMockExecClient(false), - netio.NewMockNetIO(false, 0), NewMockEndpointClient(nil), NewMockNamespaceClient(), iptables.NewClient(), epInfo) + netio.NewMockNetIO(false, 0), NewMockEndpointClient(nil), NewMockNamespaceClient(), iptables.NewClient(), &mockDHCP{}, epInfo) if endpoint != nil || err != nil { t.Fatalf("Endpoint is created for IB due to %v", err) @@ -522,7 +524,7 @@ func TestNewEndpointImplHnsv2ForIBUnHappyPath(t *testing.T) { // Set UnHappy Path _, err := nw.newEndpointImpl(nil, netlink.NewMockNetlink(false, ""), platform.NewMockExecClient(true), - netio.NewMockNetIO(false, 0), NewMockEndpointClient(nil), NewMockNamespaceClient(), iptables.NewClient(), epInfo) + netio.NewMockNetIO(false, 0), NewMockEndpointClient(nil), NewMockNamespaceClient(), iptables.NewClient(), &mockDHCP{}, epInfo) if err == nil { t.Fatal("Failed to test Endpoint creation for IB with unhappy path") @@ -562,7 +564,8 @@ func TestCreateAndDeleteEndpointImplHnsv2ForDelegatedHappyPath(t *testing.T) { } mockCli := NewMockEndpointClient(nil) - err = nw.deleteEndpointImpl(netlink.NewMockNetlink(false, ""), platform.NewMockExecClient(false), mockCli, netio.NewMockNetIO(false, 0), NewMockNamespaceClient(), iptables.NewClient(), ep) + err = nw.deleteEndpointImpl(netlink.NewMockNetlink(false, ""), platform.NewMockExecClient(false), mockCli, + netio.NewMockNetIO(false, 0), NewMockNamespaceClient(), iptables.NewClient(), &mockDHCP{}, ep) if err != nil { t.Fatalf("Failed to delete endpoint for Delegated NIC due to %v", err) } diff --git a/network/manager.go b/network/manager.go index 607ec6cc1..6b156d52c 100644 --- a/network/manager.go +++ b/network/manager.go @@ -86,6 +86,7 @@ type networkManager struct { plClient platform.ExecClient nsClient NamespaceClientInterface iptablesClient ipTablesClient + dhcpClient dhcpClient sync.Mutex } @@ -123,7 +124,7 @@ type NetworkManager interface { // Creates a new network manager. func NewNetworkManager(nl netlink.NetlinkInterface, plc platform.ExecClient, netioCli netio.NetIOInterface, nsc NamespaceClientInterface, - iptc ipTablesClient, + iptc ipTablesClient, dhcpc dhcpClient, ) (NetworkManager, error) { nm := &networkManager{ ExternalInterfaces: make(map[string]*externalInterface), @@ -132,6 +133,7 @@ func NewNetworkManager(nl netlink.NetlinkInterface, plc platform.ExecClient, net netio: netioCli, nsClient: nsc, iptablesClient: iptc, + dhcpClient: dhcpc, } return nm, nil @@ -386,7 +388,7 @@ func (nm *networkManager) createEndpoint(cli apipaClient, networkID string, epIn } } - ep, err := nw.newEndpoint(cli, nm.netlink, nm.plClient, nm.netio, nm.nsClient, nm.iptablesClient, epInfo) + ep, err := nw.newEndpoint(cli, nm.netlink, nm.plClient, nm.netio, nm.nsClient, nm.iptablesClient, nm.dhcpClient, epInfo) if err != nil { return nil, err } @@ -395,7 +397,7 @@ func (nm *networkManager) createEndpoint(cli apipaClient, networkID string, epIn if err != nil { logger.Error("Create endpoint failure", zap.Error(err)) logger.Info("Cleanup resources") - delErr := nw.deleteEndpoint(nm.netlink, nm.plClient, nm.netio, nm.nsClient, nm.iptablesClient, ep.Id) + delErr := nw.deleteEndpoint(nm.netlink, nm.plClient, nm.netio, nm.nsClient, nm.iptablesClient, nm.dhcpClient, ep.Id) if delErr != nil { logger.Error("Deleting endpoint after create endpoint failure failed with", zap.Error(delErr)) } @@ -489,7 +491,7 @@ func (nm *networkManager) DeleteEndpoint(networkID, endpointID string, epInfo *E return err } - err = nw.deleteEndpoint(nm.netlink, nm.plClient, nm.netio, nm.nsClient, nm.iptablesClient, endpointID) + err = nw.deleteEndpoint(nm.netlink, nm.plClient, nm.netio, nm.nsClient, nm.iptablesClient, nm.dhcpClient, endpointID) if err != nil { return err } @@ -531,7 +533,7 @@ func (nm *networkManager) DeleteEndpointState(networkID string, epInfo *Endpoint } logger.Info("Deleting endpoint with", zap.String("Endpoint Info: ", epInfo.PrettyString()), zap.String("HNISID : ", ep.HnsId)) - err := nw.deleteEndpointImpl(netlink.NewNetlink(), platform.NewExecClient(logger), nil, nil, nil, nil, ep) + err := nw.deleteEndpointImpl(netlink.NewNetlink(), platform.NewExecClient(logger), nil, nil, nil, nil, nil, ep) if err != nil { return err } diff --git a/network/secondary_endpoint_client_linux.go b/network/secondary_endpoint_client_linux.go index 46fc3c26f..4a41373e7 100644 --- a/network/secondary_endpoint_client_linux.go +++ b/network/secondary_endpoint_client_linux.go @@ -1,8 +1,10 @@ package network import ( + "context" "os" "strings" + "time" "github.com/Azure/azure-container-networking/netio" "github.com/Azure/azure-container-networking/netlink" @@ -25,6 +27,7 @@ type SecondaryEndpointClient struct { plClient platform.ExecClient netUtilsClient networkutils.NetworkUtils nsClient NamespaceClientInterface + dhcpClient dhcpClient ep *endpoint } @@ -33,6 +36,7 @@ func NewSecondaryEndpointClient( nioc netio.NetIOInterface, plc platform.ExecClient, nsc NamespaceClientInterface, + dhcpClient dhcpClient, endpoint *endpoint, ) *SecondaryEndpointClient { client := &SecondaryEndpointClient{ @@ -41,6 +45,7 @@ func NewSecondaryEndpointClient( plClient: plc, netUtilsClient: networkutils.NewNetworkUtils(nl, plc), nsClient: nsc, + dhcpClient: dhcpClient, ep: endpoint, } @@ -127,6 +132,19 @@ func (client *SecondaryEndpointClient) ConfigureContainerInterfacesAndRoutes(epI ifInfo.Routes = append(ifInfo.Routes, epInfo.Routes...) + // issue dhcp discover packet to ensure mapping created for dns via wireserver to work + // we do not use the response for anything + numSecs := 3 + timeout := time.Duration(numSecs) * time.Second + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(timeout)) + defer cancel() + logger.Info("Sending DHCP packet", zap.Any("macAddress", epInfo.MacAddress), zap.String("ifName", epInfo.IfName)) + err := client.dhcpClient.DiscoverRequest(ctx, epInfo.MacAddress, epInfo.IfName) + if err != nil { + return errors.Wrapf(err, "failed to issue dhcp discover packet to create mapping in host") + } + logger.Info("Finished configuring container interfaces and routes for secondary endpoint client") + return nil } diff --git a/network/secondary_endpoint_linux_test.go b/network/secondary_endpoint_linux_test.go index 2d8ca0556..150efd1de 100644 --- a/network/secondary_endpoint_linux_test.go +++ b/network/secondary_endpoint_linux_test.go @@ -36,6 +36,7 @@ func TestSecondaryAddEndpoints(t *testing.T) { netUtilsClient: networkutils.NewNetworkUtils(nl, plc), netioshim: netio.NewMockNetIO(false, 0), ep: &endpoint{SecondaryInterfaces: make(map[string]*InterfaceInfo)}, + dhcpClient: &mockDHCP{}, }, epInfo: &EndpointInfo{MacAddress: mac}, wantErr: false, @@ -255,6 +256,7 @@ func TestSecondaryConfigureContainerInterfacesAndRoutes(t *testing.T) { plClient: platform.NewMockExecClient(false), netUtilsClient: networkutils.NewNetworkUtils(nl, plc), netioshim: netio.NewMockNetIO(false, 0), + dhcpClient: &mockDHCP{}, ep: &endpoint{SecondaryInterfaces: map[string]*InterfaceInfo{"eth1": {Name: "eth1"}}}, }, epInfo: &EndpointInfo{ @@ -280,6 +282,7 @@ func TestSecondaryConfigureContainerInterfacesAndRoutes(t *testing.T) { plClient: platform.NewMockExecClient(false), netUtilsClient: networkutils.NewNetworkUtils(netlink.NewMockNetlink(true, ""), plc), netioshim: netio.NewMockNetIO(false, 0), + dhcpClient: &mockDHCP{}, ep: &endpoint{SecondaryInterfaces: map[string]*InterfaceInfo{"eth1": {Name: "eth1"}}}, }, epInfo: &EndpointInfo{ @@ -301,6 +304,7 @@ func TestSecondaryConfigureContainerInterfacesAndRoutes(t *testing.T) { plClient: platform.NewMockExecClient(false), netUtilsClient: networkutils.NewNetworkUtils(nl, plc), netioshim: netio.NewMockNetIO(true, 1), + dhcpClient: &mockDHCP{}, ep: &endpoint{SecondaryInterfaces: map[string]*InterfaceInfo{"eth1": {Name: "eth1"}}}, }, epInfo: &EndpointInfo{ @@ -327,6 +331,7 @@ func TestSecondaryConfigureContainerInterfacesAndRoutes(t *testing.T) { plClient: platform.NewMockExecClient(false), netUtilsClient: networkutils.NewNetworkUtils(nl, plc), netioshim: netio.NewMockNetIO(false, 0), + dhcpClient: &mockDHCP{}, ep: &endpoint{SecondaryInterfaces: map[string]*InterfaceInfo{"eth1": {Name: "eth1"}}}, }, epInfo: &EndpointInfo{ @@ -348,6 +353,7 @@ func TestSecondaryConfigureContainerInterfacesAndRoutes(t *testing.T) { plClient: platform.NewMockExecClient(false), netUtilsClient: networkutils.NewNetworkUtils(nl, plc), netioshim: netio.NewMockNetIO(false, 0), + dhcpClient: &mockDHCP{}, ep: &endpoint{SecondaryInterfaces: map[string]*InterfaceInfo{"eth1": {Name: "eth1"}}}, }, epInfo: &EndpointInfo{