diff --git a/README.md b/README.md index 477cba2..06663a1 100755 --- a/README.md +++ b/README.md @@ -14,6 +14,16 @@ libaudit-go can be used to build go applications which perform tasks similar to To get started see package documentation at [godoc](https://godoc.org/github.com/mozilla/libaudit-go). +For a simple example of usage, see the [auditprint](./auditprint/) tool included in this repository. + +```bash +sudo service stop auditd +go get -u github.com/mozilla/libaudit-go +cd $GOPATH/src/github.com/mozilla/libaudit-go +go install github.com/mozilla/libaudit-go/auditprint +sudo $GOPATH/bin/auditprint testdata/rules.json +``` + Some key functions are discussed in the overview section below. ## Overview diff --git a/audit_events.go b/audit_events.go index c463807..3ce8e43 100644 --- a/audit_events.go +++ b/audit_events.go @@ -31,6 +31,14 @@ type AuditEvent struct { // NewAuditEvent takes a NetlinkMessage passed from the netlink connection and parses the data // from the message header to return an AuditEvent type. +// +// Note that it is possible here that we don't have a full event to return. In some cases, a +// single audit event may be represented by multiple audit events from the kernel. This function +// looks after buffering partial fragments of a full event, and may only return the complete event +// once an AUDIT_EOE record has been recieved for the audit event. +// +// See https://www.redhat.com/archives/linux-audit/2016-January/msg00019.html for additional information +// on the behavior of this function. func NewAuditEvent(msg NetlinkMessage) (*AuditEvent, error) { x, err := ParseAuditEvent(string(msg.Data[:]), auditConstant(msg.Header.Type), true) if err != nil { @@ -40,7 +48,23 @@ func NewAuditEvent(msg NetlinkMessage) (*AuditEvent, error) { return nil, fmt.Errorf("unknown message type %d", msg.Header.Type) } - return x, nil + // Determine if the event type is one which the kernel is expected to send only a single + // packet for; in these cases we don't need to look into buffering it and can return the + // event immediately. + if auditConstant(msg.Header.Type) < AUDIT_SYSCALL || + auditConstant(msg.Header.Type) >= AUDIT_FIRST_ANOM_MSG { + return x, nil + } + + // If this is an EOE message, get the entire processed message and return it. + if auditConstant(msg.Header.Type) == AUDIT_EOE { + return bufferGet(x), nil + } + + // Otherwise we need to buffer this message. + bufferEvent(x) + + return nil, nil } // GetAuditEvents receives audit messages from the kernel and parses them into an AuditEvent. @@ -63,6 +87,9 @@ func GetAuditEvents(s Netlink, cb EventCallback) { } } else { nae, err := NewAuditEvent(msg) + if nae == nil { + continue + } cb(nae, err) } } @@ -125,6 +152,9 @@ func GetAuditMessages(s Netlink, cb EventCallback, done *chan bool) { } } else { nae, err := NewAuditEvent(msg) + if nae == nil { + continue + } cb(nae, err) } } diff --git a/auditprint/auditprint.go b/auditprint/auditprint.go new file mode 100644 index 0000000..0d9b386 --- /dev/null +++ b/auditprint/auditprint.go @@ -0,0 +1,105 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +// auditprint is a simple command line tool that loads an audit rule set from a JSON file, +// applies it to the current kernel and begins printing any audit event the kernel sends in +// JSON format. +package main + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "os" + + "github.com/mozilla/libaudit-go" +) + +func auditProc(e *libaudit.AuditEvent, err error) { + if err != nil { + // See if the error is libaudit.ErrorAuditParse, if so convert and also display + // the audit record we could not parse + if nerr, ok := err.(libaudit.ErrorAuditParse); ok { + fmt.Printf("parser error: %v: %v\n", nerr, nerr.Raw) + } else { + fmt.Printf("callback received error: %v\n", err) + } + return + } + // Marshal the event to JSON and print + buf, err := json.Marshal(e) + if err != nil { + fmt.Printf("callback was unable to marshal event: %v\n", err) + return + } + fmt.Printf("%v\n", string(buf)) +} + +func main() { + s, err := libaudit.NewNetlinkConnection() + if err != nil { + fmt.Fprintf(os.Stderr, "NetNetlinkConnection: %v\n", err) + os.Exit(1) + } + + if len(os.Args) != 2 { + fmt.Printf("usage: %v path_to_rules.json\n", os.Args[0]) + os.Exit(0) + } + + err = libaudit.AuditSetEnabled(s, true) + if err != nil { + fmt.Fprintf(os.Stderr, "AuditSetEnabled: %v\n", err) + os.Exit(1) + } + + err = libaudit.AuditSetPID(s, os.Getpid()) + if err != nil { + fmt.Fprintf(os.Stderr, "AuditSetPid: %v\n", err) + os.Exit(1) + } + err = libaudit.AuditSetRateLimit(s, 1000) + if err != nil { + fmt.Fprintf(os.Stderr, "AuditSetRateLimit: %v\n", err) + os.Exit(1) + } + err = libaudit.AuditSetBacklogLimit(s, 250) + if err != nil { + fmt.Fprintf(os.Stderr, "AuditSetBacklogLimit: %v\n", err) + os.Exit(1) + } + + var ar libaudit.AuditRules + buf, err := ioutil.ReadFile(os.Args[1]) + if err != nil { + fmt.Fprintf(os.Stderr, "ReadFile: %v\n", err) + os.Exit(1) + } + // Make sure we can unmarshal the rules JSON to validate it is the correct + // format + err = json.Unmarshal(buf, &ar) + if err != nil { + fmt.Fprintf(os.Stderr, "Unmarshaling rules JSON: %v\n", err) + os.Exit(1) + } + + // Remove current rule set and send rules to the kernel + err = libaudit.DeleteAllRules(s) + if err != nil { + fmt.Fprintf(os.Stderr, "DeleteAllRules: %v\n", err) + os.Exit(1) + } + warnings, err := libaudit.SetRules(s, buf) + if err != nil { + fmt.Fprintf(os.Stderr, "SetRules: %v\n", err) + os.Exit(1) + } + // Print any warnings we got back but still continue + for _, x := range warnings { + fmt.Fprintf(os.Stderr, "ruleset warning: %v\n", x) + } + + doneCh := make(chan bool, 1) + libaudit.GetAuditMessages(s, auditProc, &doneCh) +} diff --git a/buffer.go b/buffer.go new file mode 100644 index 0000000..117d718 --- /dev/null +++ b/buffer.go @@ -0,0 +1,58 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package libaudit + +import ( + "strconv" +) + +var bufferMap map[uint64][]*AuditEvent + +// bufferEvent buffers an incoming audit event which contains partial record informatioa. +func bufferEvent(a *AuditEvent) { + if bufferMap == nil { + bufferMap = make(map[uint64][]*AuditEvent) + } + + serial, err := strconv.ParseUint(a.Serial, 10, 64) + if err != nil { + return + } + if _, ok := bufferMap[serial]; !ok { + bufferMap[serial] = make([]*AuditEvent, 5) + } + bufferMap[serial] = append(bufferMap[serial], a) +} + +// bufferGet returns the complete audit event from the buffer, given the AUDIT_EOE event a. +func bufferGet(a *AuditEvent) *AuditEvent { + serial, err := strconv.ParseUint(a.Serial, 10, 64) + if err != nil { + return nil + } + + var ( + bm []*AuditEvent + ok bool + ) + if bm, ok = bufferMap[serial]; !ok { + return nil + } + rlen := len(a.Raw) + for i := range bm { + if bm[i] == nil { + continue + } + for k, v := range bm[i].Data { + a.Data[k] = v + } + if len(bm[i].Raw) > rlen { + a.Raw += " " + bm[i].Raw[rlen:] + } + } + + delete(bufferMap, serial) + return a +} diff --git a/libaudit.go b/libaudit.go index 9e786f3..bea2ecb 100644 --- a/libaudit.go +++ b/libaudit.go @@ -175,6 +175,7 @@ func parseAuditNetlinkMessage(b []byte) (ret []NetlinkMessage, err error) { var ( m NetlinkMessage ) + m.Header.Len, b, err = netlinkPopuint32(b) if err != nil { return @@ -182,6 +183,7 @@ func parseAuditNetlinkMessage(b []byte) (ret []NetlinkMessage, err error) { // Determine our alignment size given the reported header length alignbounds := nlmAlignOf(int(m.Header.Len)) padding := alignbounds - int(m.Header.Len) + // Subtract 4 from alignbounds here to account for already having popped 4 bytes // off the input buffer if len(b) < alignbounds-4 { @@ -208,9 +210,27 @@ func parseAuditNetlinkMessage(b []byte) (ret []NetlinkMessage, err error) { if err != nil { return ret, err } - datalen := m.Header.Len - syscall.NLMSG_HDRLEN - m.Data = b[:datalen] - b = b[int(datalen)+padding:] + // Determine how much data we want to read here; if this isn't NLM_F_MULTI, we'd + // typically want to read m.Header.Len bytes (the length of the payload indicated in + // the netlink header. + // + // However, this isn't always the case. Depending on what is generating the audit + // message (e.g., via audit_log_end) the kernel does not include the netlink header + // size in the submitted audit message. So, we just read whatever is left in the buffer + // we have if this isn't multipart. + // + // Additionally, it seems like there are also a few messages types where the netlink paylaod + // value is inaccurate and can't be relied upon. + // + // XXX Just consuming the rest of the buffer based on the event type might be a better + // approach here. + if !multi { + m.Data = b + } else { + datalen := m.Header.Len - syscall.NLMSG_HDRLEN + m.Data = b[:datalen] + b = b[int(datalen)+padding:] + } ret = append(ret, m) if !multi { break diff --git a/parser.go b/parser.go index a8dcc20..a1624b1 100644 --- a/parser.go +++ b/parser.go @@ -17,6 +17,28 @@ type record struct { a1 int } +// ErrorAuditParse is an implementation of the error interface that is returned by +// ParseAuditEvent. msg will contain a description of the error, and the raw audit event +// which failed parsing is returned in raw for inspection by the calling program. +type ErrorAuditParse struct { + Msg string + Raw string +} + +// Error returns a string representation of ErrorAuditParse e +func (e ErrorAuditParse) Error() string { + return e.Msg +} + +// newErrorAuditParse returns a new ErrorAuditParse type with the fields populated +func newErrorAuditParse(raw string, f string, v ...interface{}) ErrorAuditParse { + ret := ErrorAuditParse{ + Raw: raw, + Msg: fmt.Sprintf(f, v...), + } + return ret +} + // ParseAuditEvent parses an incoming audit message from kernel and returns an AuditEvent. // // msgType is supposed to come from the calling function which holds the msg header indicating header @@ -36,11 +58,11 @@ func ParseAuditEvent(str string, msgType auditConstant, interpret bool) (*AuditE if strings.HasPrefix(str, "audit(") { str = str[6:] } else { - return nil, fmt.Errorf("malformed audit message") + return nil, newErrorAuditParse(event.Raw, "malformed, missing audit prefix") } index := strings.Index(str, ":") if index == -1 { - return nil, fmt.Errorf("malformed audit message") + return nil, newErrorAuditParse(event.Raw, "malformed, can't locate start of fields") } // determine timestamp @@ -49,13 +71,13 @@ func ParseAuditEvent(str string, msgType auditConstant, interpret bool) (*AuditE str = str[index+1:] index = strings.Index(str, ")") if index == -1 { - return nil, fmt.Errorf("malformed audit message") + return nil, newErrorAuditParse(event.Raw, "malformed, can't locate end of prefix") } serial := str[:index] if strings.HasPrefix(str, serial+"): ") { str = str[index+3:] } else { - return nil, fmt.Errorf("malformed audit message") + return nil, newErrorAuditParse(event.Raw, "malformed, prefix termination unexpected") } var ( @@ -87,7 +109,7 @@ func ParseAuditEvent(str string, msgType auditConstant, interpret bool) (*AuditE var err error value, err = interpretField(key, value, msgType, r) if err != nil { - return nil, err + return nil, newErrorAuditParse(event.Raw, "interpretField: %v", err) } } m[key] = value @@ -117,7 +139,7 @@ func ParseAuditEvent(str string, msgType auditConstant, interpret bool) (*AuditE var err error value, err = interpretField(key, value, msgType, r) if err != nil { - return nil, err + return nil, newErrorAuditParse(event.Raw, "interpretField: %v", err) } } m[key] = value @@ -142,7 +164,7 @@ func ParseAuditEvent(str string, msgType auditConstant, interpret bool) (*AuditE var err error value, err = interpretField(key, value, msgType, r) if err != nil { - return nil, err + return nil, newErrorAuditParse(event.Raw, "interpretField: %v", err) } } m[key] = value @@ -155,7 +177,7 @@ func ParseAuditEvent(str string, msgType auditConstant, interpret bool) (*AuditE var err error value, err = interpretField(key, value, msgType, r) if err != nil { - return nil, err + return nil, newErrorAuditParse(event.Raw, "interpretField: %v", err) } } m[key] = value @@ -206,7 +228,7 @@ func ParseAuditEvent(str string, msgType auditConstant, interpret bool) (*AuditE var err error value, err = interpretField(key, value, msgType, r) if err != nil { - return nil, err + return nil, newErrorAuditParse(event.Raw, "interpretField: %v", err) } } m[key] = value diff --git a/parser_test.go b/parser_test.go index dd18e1a..8aaccc5 100644 --- a/parser_test.go +++ b/parser_test.go @@ -280,17 +280,18 @@ func TestMalformedPrefix(t *testing.T) { tmsg := []struct { msg string msgType auditConstant + err string }{ - {"xyzabc", AUDIT_AVC}, - {`audit(1464163771`, AUDIT_AVC}, - {`audit(1464176620.068:1445`, AUDIT_AVC}, + {"xyzabc", AUDIT_AVC, "malformed, missing audit prefix"}, + {`audit(1464163771`, AUDIT_AVC, "malformed, can't locate start of fields"}, + {`audit(1464176620.068:1445`, AUDIT_AVC, "malformed, can't locate end of prefix"}, } for _, m := range tmsg { _, err := ParseAuditEvent(m.msg, m.msgType, false) if err == nil { t.Fatalf("ParseAuditEvent should have failed on %q", m.msg) } - if err.Error() != "malformed audit message" { + if err.Error() != m.err { t.Fatalf("ParseAuditEvent failed, but error %q was unexpected", err) } }