2017-02-12 13:03:39 +03:00
|
|
|
// Copyright 2017 Microsoft. All rights reserved.
|
|
|
|
// MIT License
|
2016-04-19 04:09:24 +03:00
|
|
|
|
2022-08-17 20:53:57 +03:00
|
|
|
//go:build linux
|
2018-09-20 01:29:42 +03:00
|
|
|
// +build linux
|
|
|
|
|
2016-04-19 04:09:24 +03:00
|
|
|
package netlink
|
|
|
|
|
|
|
|
import (
|
|
|
|
"fmt"
|
|
|
|
"sync"
|
|
|
|
"sync/atomic"
|
|
|
|
"syscall"
|
|
|
|
|
2016-10-07 00:40:29 +03:00
|
|
|
"github.com/Azure/azure-container-networking/log"
|
2016-06-05 16:42:14 +03:00
|
|
|
"golang.org/x/sys/unix"
|
2016-04-19 04:09:24 +03:00
|
|
|
)
|
|
|
|
|
|
|
|
// Represents a netlink socket.
|
|
|
|
type socket struct {
|
|
|
|
fd int
|
|
|
|
sa unix.SockaddrNetlink
|
|
|
|
pid uint32
|
|
|
|
seq uint32
|
|
|
|
sync.Mutex
|
|
|
|
}
|
|
|
|
|
|
|
|
// Default netlink socket.
|
2021-09-03 00:33:18 +03:00
|
|
|
var (
|
|
|
|
s *socket
|
|
|
|
m sync.Mutex
|
|
|
|
)
|
2016-04-19 04:09:24 +03:00
|
|
|
|
|
|
|
// Returns a reference to the default netlink socket.
|
|
|
|
func getSocket() (*socket, error) {
|
|
|
|
var err error
|
2016-11-23 02:24:28 +03:00
|
|
|
|
|
|
|
m.Lock()
|
|
|
|
defer m.Unlock()
|
|
|
|
|
|
|
|
if s == nil {
|
|
|
|
s, err = newSocket()
|
|
|
|
}
|
|
|
|
|
2016-04-19 04:09:24 +03:00
|
|
|
return s, err
|
|
|
|
}
|
|
|
|
|
2016-11-23 02:24:28 +03:00
|
|
|
// ResetSocket deletes the default netlink socket.
|
|
|
|
func ResetSocket() {
|
|
|
|
m.Lock()
|
|
|
|
defer m.Unlock()
|
|
|
|
|
|
|
|
s = nil
|
|
|
|
}
|
|
|
|
|
2016-04-19 04:09:24 +03:00
|
|
|
// Creates a new netlink socket object.
|
|
|
|
func newSocket() (*socket, error) {
|
|
|
|
fd, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE)
|
|
|
|
if err != nil {
|
2016-11-23 02:24:28 +03:00
|
|
|
log.Debugf("[netlink] Failed to create socket, err=%v\n", err)
|
2016-04-19 04:09:24 +03:00
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
s := &socket{
|
|
|
|
fd: fd,
|
|
|
|
pid: uint32(unix.Getpid()),
|
|
|
|
seq: 0,
|
|
|
|
}
|
|
|
|
|
|
|
|
s.sa.Family = unix.AF_NETLINK
|
|
|
|
|
|
|
|
err = unix.Bind(fd, &s.sa)
|
|
|
|
if err != nil {
|
|
|
|
unix.Close(fd)
|
2016-11-23 02:24:28 +03:00
|
|
|
log.Debugf("[netlink] Failed to bind socket, err=%v\n", err)
|
2016-04-19 04:09:24 +03:00
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
2016-11-23 02:24:28 +03:00
|
|
|
log.Debugf("[netlink] Socket created.\n")
|
2016-04-19 04:09:24 +03:00
|
|
|
return s, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
// Closes the socket.
|
|
|
|
func (s *socket) close() {
|
|
|
|
err := unix.Close(s.fd)
|
2016-06-06 06:58:47 +03:00
|
|
|
log.Debugf("[netlink] Socket closed, err=%v\n", err)
|
2016-04-19 04:09:24 +03:00
|
|
|
}
|
|
|
|
|
|
|
|
// Sends a netlink message.
|
|
|
|
func (s *socket) send(msg *message) error {
|
|
|
|
msg.Seq = atomic.AddUint32(&s.seq, 1)
|
|
|
|
err := unix.Sendto(s.fd, msg.serialize(), 0, &s.sa)
|
2016-06-06 06:58:47 +03:00
|
|
|
log.Debugf("[netlink] Sent %+v, err=%v\n", *msg, err)
|
2016-04-19 04:09:24 +03:00
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
2016-06-05 16:42:14 +03:00
|
|
|
// Sends a netlink message and blocks until its response is received.
|
|
|
|
func (s *socket) sendAndWaitForResponse(msg *message) ([]*message, error) {
|
2016-04-19 04:09:24 +03:00
|
|
|
s.Lock()
|
|
|
|
defer s.Unlock()
|
|
|
|
|
|
|
|
err := s.send(msg)
|
|
|
|
if err != nil {
|
2016-06-05 16:42:14 +03:00
|
|
|
return nil, err
|
2016-04-19 04:09:24 +03:00
|
|
|
}
|
|
|
|
|
2016-06-05 16:42:14 +03:00
|
|
|
return s.receiveResponse(msg)
|
|
|
|
}
|
|
|
|
|
|
|
|
// Sends a netlink message and blocks until its ack is received.
|
|
|
|
func (s *socket) sendAndWaitForAck(msg *message) error {
|
|
|
|
_, err := s.sendAndWaitForResponse(msg)
|
|
|
|
return err
|
2016-04-19 04:09:24 +03:00
|
|
|
}
|
|
|
|
|
|
|
|
// Receives a netlink message.
|
|
|
|
func (s *socket) receive() ([]syscall.NetlinkMessage, error) {
|
|
|
|
buffer := make([]byte, unix.Getpagesize())
|
|
|
|
n, _, err := unix.Recvfrom(s.fd, buffer, 0)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
if n < unix.NLMSG_HDRLEN {
|
|
|
|
return nil, fmt.Errorf("Invalid netlink message")
|
|
|
|
}
|
|
|
|
|
|
|
|
buffer = buffer[:n]
|
|
|
|
return syscall.ParseNetlinkMessage(buffer)
|
|
|
|
}
|
|
|
|
|
2016-06-05 16:42:14 +03:00
|
|
|
// Receives the response for the given sent message and returns the parsed message.
|
|
|
|
func (s *socket) receiveResponse(sent *message) ([]*message, error) {
|
|
|
|
var messages []*message
|
|
|
|
var multi, done bool
|
|
|
|
|
2016-04-19 04:09:24 +03:00
|
|
|
for {
|
2016-06-05 16:42:14 +03:00
|
|
|
// Receive all pending messages.
|
|
|
|
nlMsgs, err := s.receive()
|
2016-04-19 04:09:24 +03:00
|
|
|
if err != nil {
|
|
|
|
log.Printf("[netlink] Receive err=%v\n", err)
|
2016-06-05 16:42:14 +03:00
|
|
|
return messages, err
|
2016-04-19 04:09:24 +03:00
|
|
|
}
|
|
|
|
|
2016-06-05 16:42:14 +03:00
|
|
|
// Process received messages.
|
|
|
|
for _, nlMsg := range nlMsgs {
|
|
|
|
// Convert to message object.
|
|
|
|
msg := message{
|
|
|
|
NlMsghdr: unix.NlMsghdr{
|
|
|
|
Len: nlMsg.Header.Len,
|
|
|
|
Type: nlMsg.Header.Type,
|
|
|
|
Flags: nlMsg.Header.Flags,
|
|
|
|
Seq: nlMsg.Header.Seq,
|
|
|
|
Pid: nlMsg.Header.Pid,
|
|
|
|
},
|
|
|
|
data: nlMsg.Data,
|
|
|
|
}
|
|
|
|
|
|
|
|
// Ignore if the message is not in response to the sent message.
|
|
|
|
if msg.Seq != sent.Seq || msg.Pid != sent.Pid {
|
|
|
|
log.Printf("[netlink] Ignoring unexpected message %+v\n", msg)
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
|
|
|
|
// Return if this is an ack or an error message.
|
2016-04-19 04:09:24 +03:00
|
|
|
// An acknowledgement is an error message with error code set to
|
|
|
|
// zero, followed by the original request message header.
|
2016-06-05 16:42:14 +03:00
|
|
|
if msg.Type == unix.NLMSG_ERROR {
|
|
|
|
errCode := int32(encoder.Uint32(msg.data[0:4]))
|
|
|
|
if errCode == 0 {
|
2016-06-06 06:58:47 +03:00
|
|
|
log.Debugf("[netlink] Received %+v, ack\n", msg)
|
2016-06-05 16:42:14 +03:00
|
|
|
} else {
|
2016-04-19 04:09:24 +03:00
|
|
|
err = syscall.Errno(-errCode)
|
2016-06-05 16:42:14 +03:00
|
|
|
log.Printf("[netlink] Received %+v, err=%v\n", msg, err)
|
2016-04-19 04:09:24 +03:00
|
|
|
}
|
2016-06-05 16:42:14 +03:00
|
|
|
return nil, err
|
|
|
|
}
|
2016-04-19 04:09:24 +03:00
|
|
|
|
2016-06-05 16:42:14 +03:00
|
|
|
// Log response message.
|
2016-06-06 06:58:47 +03:00
|
|
|
log.Debugf("[netlink] Received %+v\n", msg)
|
2016-04-19 04:09:24 +03:00
|
|
|
|
2016-06-05 16:42:14 +03:00
|
|
|
// Parse body.
|
|
|
|
msg.payload = append(msg.payload, nil)
|
|
|
|
|
|
|
|
// Parse attributes.
|
|
|
|
// Ignore failures as not all messages have attributes.
|
|
|
|
nlAttrs, _ := syscall.ParseNetlinkRouteAttr(&nlMsg)
|
|
|
|
|
|
|
|
// Convert to attribute objects.
|
|
|
|
for _, nlAttr := range nlAttrs {
|
|
|
|
attr := attribute{
|
|
|
|
NlAttr: unix.NlAttr{
|
|
|
|
Len: nlAttr.Attr.Len,
|
|
|
|
Type: nlAttr.Attr.Type,
|
|
|
|
},
|
|
|
|
value: nlAttr.Value,
|
|
|
|
}
|
|
|
|
msg.payload = append(msg.payload, &attr)
|
|
|
|
}
|
|
|
|
|
|
|
|
multi = ((msg.Flags & unix.NLM_F_MULTI) != 0)
|
|
|
|
done = (msg.Type == unix.NLMSG_DONE)
|
|
|
|
|
|
|
|
// Exit if message completes a multipart response.
|
|
|
|
if multi && done {
|
|
|
|
break
|
2016-04-19 04:09:24 +03:00
|
|
|
}
|
2016-06-05 16:42:14 +03:00
|
|
|
|
|
|
|
messages = append(messages, &msg)
|
|
|
|
}
|
|
|
|
|
|
|
|
// Exit if response is a single message,
|
|
|
|
// or a completed multipart message.
|
|
|
|
if !multi || done {
|
|
|
|
break
|
2016-04-19 04:09:24 +03:00
|
|
|
}
|
|
|
|
}
|
2016-06-05 16:42:14 +03:00
|
|
|
|
|
|
|
return messages, nil
|
2016-04-19 04:09:24 +03:00
|
|
|
}
|