From 7e379daf96ce737485c479aa19b12dd96b6761b5 Mon Sep 17 00:00:00 2001 From: Tatiana Bradley Date: Mon, 18 Dec 2023 12:41:25 -0500 Subject: [PATCH] internal/{cveutils,worker}: move CVE triage to cveutils Move existing logic to triage v4 CVEs to the cveutils package. This will make it easier to add tests and implement triage for v5 CVEs. Change-Id: I4872af391a33500dd7236795a910ad3a6998b5e0 Reviewed-on: https://go-review.googlesource.com/c/vulndb/+/550857 LUCI-TryBot-Result: Go LUCI Reviewed-by: Damien Neil --- cmd/worker/main.go | 3 +- internal/{worker => cveutils}/paths.go | 2 +- internal/{worker => cveutils}/paths_test.go | 2 +- internal/cveutils/pkgsite.go | 99 +++++++++++++++ internal/{worker => cveutils}/triage.go | 121 +++++-------------- internal/{worker => cveutils}/triage_test.go | 67 ++++------ internal/worker/update.go | 21 ++-- internal/worker/update_test.go | 10 +- internal/worker/worker.go | 5 +- 9 files changed, 176 insertions(+), 154 deletions(-) rename internal/{worker => cveutils}/paths.go (99%) rename internal/{worker => cveutils}/paths_test.go (99%) create mode 100644 internal/cveutils/pkgsite.go rename internal/{worker => cveutils}/triage.go (61%) rename internal/{worker => cveutils}/triage_test.go (73%) diff --git a/cmd/worker/main.go b/cmd/worker/main.go index adbf91b9..db173aa7 100644 --- a/cmd/worker/main.go +++ b/cmd/worker/main.go @@ -21,6 +21,7 @@ import ( "time" "golang.org/x/vulndb/internal/cvelistrepo" + "golang.org/x/vulndb/internal/cveutils" "golang.org/x/vulndb/internal/ghsa" "golang.org/x/vulndb/internal/gitrepo" "golang.org/x/vulndb/internal/issues" @@ -237,7 +238,7 @@ func populateKnownModules(filename string) error { if err := scan.Err(); err != nil { return err } - worker.SetKnownModules(mods) + cveutils.SetKnownModules(mods) fmt.Printf("set %d known modules\n", len(mods)) return nil } diff --git a/internal/worker/paths.go b/internal/cveutils/paths.go similarity index 99% rename from internal/worker/paths.go rename to internal/cveutils/paths.go index 08400534..006de1eb 100644 --- a/internal/worker/paths.go +++ b/internal/cveutils/paths.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package worker +package cveutils import ( "path" diff --git a/internal/worker/paths_test.go b/internal/cveutils/paths_test.go similarity index 99% rename from internal/worker/paths_test.go rename to internal/cveutils/paths_test.go index ce027e87..0bfec140 100644 --- a/internal/worker/paths_test.go +++ b/internal/cveutils/paths_test.go @@ -5,7 +5,7 @@ //go:build go1.17 // +build go1.17 -package worker +package cveutils import ( "testing" diff --git a/internal/cveutils/pkgsite.go b/internal/cveutils/pkgsite.go new file mode 100644 index 00000000..993d78f9 --- /dev/null +++ b/internal/cveutils/pkgsite.go @@ -0,0 +1,99 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cveutils + +import ( + "context" + "net/http" + "net/http/httptest" + "strconv" + "strings" + "testing" + "time" + + "golang.org/x/time/rate" + "golang.org/x/vulndb/internal/worker/log" +) + +// Limit pkgsite requests to this many per second. +const pkgsiteQPS = 5 + +var ( + // The limiter used to throttle pkgsite requests. + // The second argument to rate.NewLimiter is the burst, which + // basically lets you exceed the rate briefly. + pkgsiteRateLimiter = rate.NewLimiter(rate.Every(time.Duration(1000/float64(pkgsiteQPS))*time.Millisecond), 3) + + // Cache of module paths already seen. + seenModulePath = map[string]bool{} + // Does seenModulePath contain all known modules? + cacheComplete = false +) + +// SetKnownModules provides a list of all known modules, +// so that no requests need to be made to pkg.go.dev. +func SetKnownModules(mods []string) { + for _, m := range mods { + seenModulePath[m] = true + } + cacheComplete = true +} + +var pkgsiteURL = "https://pkg.go.dev" + +// knownToPkgsite reports whether pkgsite knows that modulePath actually refers +// to a module. +func knownToPkgsite(ctx context.Context, baseURL, modulePath string) (bool, error) { + // If we've seen it before, no need to call. + if b, ok := seenModulePath[modulePath]; ok { + return b, nil + } + if cacheComplete { + return false, nil + } + // Pause to maintain a max QPS. + if err := pkgsiteRateLimiter.Wait(ctx); err != nil { + return false, err + } + start := time.Now() + + url := baseURL + "/mod/" + modulePath + res, err := http.Head(url) + var status string + if err == nil { + status = strconv.Quote(res.Status) + } + log.With( + "latency", time.Since(start), + "status", status, + "error", err, + ).Debugf(ctx, "checked if %s is known to pkgsite at HEAD", url) + if err != nil { + return false, err + } + known := res.StatusCode == http.StatusOK + seenModulePath[modulePath] = known + return known, nil +} + +// GetPkgsiteURL returns a URL to either a fake server or the real pkg.go.dev, +// depending on the useRealPkgsite value. +// +// For testing. +func GetPkgsiteURL(t *testing.T, useRealPkgsite bool) string { + if useRealPkgsite { + return pkgsiteURL + } + // Start a test server that recognizes anything from golang.org and bitbucket.org/foo/bar/baz. + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + modulePath := strings.TrimPrefix(r.URL.Path, "/mod/") + if !strings.HasPrefix(modulePath, "golang.org/") && + !strings.HasPrefix(modulePath, "bitbucket.org/foo/bar/baz") { + http.Error(w, "unknown", http.StatusNotFound) + } + })) + t.Cleanup(s.Close) + return s.URL +} diff --git a/internal/worker/triage.go b/internal/cveutils/triage.go similarity index 61% rename from internal/worker/triage.go rename to internal/cveutils/triage.go index 73099148..a47f7820 100644 --- a/internal/worker/triage.go +++ b/internal/cveutils/triage.go @@ -2,20 +2,16 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package worker +package cveutils import ( "context" "errors" "fmt" - "net/http" "net/url" "regexp" - "strconv" "strings" - "time" - "golang.org/x/time/rate" "golang.org/x/vulndb/internal/cveschema" "golang.org/x/vulndb/internal/derrors" "golang.org/x/vulndb/internal/ghsa" @@ -40,7 +36,7 @@ var stdlibReferenceDataKeywords = []string{ const unknownPath = "Path is unknown" // TriageCVE reports whether the CVE refers to a Go module. -func TriageCVE(ctx context.Context, c *cveschema.CVE, pkgsiteURL string) (_ *triageResult, err error) { +func TriageCVE(ctx context.Context, c *cveschema.CVE, pkgsiteURL string) (_ *TriageResult, err error) { defer derrors.Wrap(&err, "triageCVE(%q)", c.ID) switch c.DataVersion { case "4.0": @@ -51,10 +47,10 @@ func TriageCVE(ctx context.Context, c *cveschema.CVE, pkgsiteURL string) (_ *tri } } -type triageResult struct { - modulePath string - packagePath string - reason string +type TriageResult struct { + ModulePath string + PackagePath string + Reason string } // gopkgHosts are hostnames for popular Go package websites. @@ -83,7 +79,7 @@ var notGoModules = map[string]bool{ } // triageV4CVE triages a CVE following schema v4.0 and returns the result. -func triageV4CVE(ctx context.Context, c *cveschema.CVE, pkgsiteURL string) (result *triageResult, err error) { +func triageV4CVE(ctx context.Context, c *cveschema.CVE, pkgsiteURL string) (result *TriageResult, err error) { defer derrors.Wrap(&err, "triageV4CVE(ctx, %q, %q)", c.ID, pkgsiteURL) defer func() { if err != nil { @@ -94,7 +90,7 @@ func triageV4CVE(ctx context.Context, c *cveschema.CVE, pkgsiteURL string) (resu log.Debugf(ctx, "%s: not Go vuln", msg) return } - log.Debugf(ctx, "%s: is Go vuln:\n%s", msg, result.reason) + log.Debugf(ctx, "%s: is Go vuln:\n%s", msg, result.Reason) }() for _, r := range c.References.Data { if r.URL == "" { @@ -106,24 +102,24 @@ func triageV4CVE(ctx context.Context, c *cveschema.CVE, pkgsiteURL string) (resu } if strings.Contains(r.URL, "golang.org/pkg") { mp := strings.TrimPrefix(refURL.Path, "/pkg/") - return &triageResult{ - packagePath: mp, - modulePath: stdlib.ModulePath, - reason: fmt.Sprintf("Reference data URL %q contains path %q", r.URL, mp), + return &TriageResult{ + PackagePath: mp, + ModulePath: stdlib.ModulePath, + Reason: fmt.Sprintf("Reference data URL %q contains path %q", r.URL, mp), }, nil } if gopkgHosts[refURL.Host] { mp := strings.TrimPrefix(refURL.Path, "/") if stdlib.Contains(mp) { - return &triageResult{ - packagePath: mp, - modulePath: stdlib.ModulePath, - reason: fmt.Sprintf("Reference data URL %q contains path %q", r.URL, mp), + return &TriageResult{ + PackagePath: mp, + ModulePath: stdlib.ModulePath, + Reason: fmt.Sprintf("Reference data URL %q contains path %q", r.URL, mp), }, nil } - return &triageResult{ - modulePath: mp, - reason: fmt.Sprintf("Reference data URL %q contains path %q", r.URL, mp), + return &TriageResult{ + ModulePath: mp, + Reason: fmt.Sprintf("Reference data URL %q contains path %q", r.URL, mp), }, nil } modpaths := candidateModulePaths(refURL.Host + refURL.Path) @@ -137,9 +133,9 @@ func triageV4CVE(ctx context.Context, c *cveschema.CVE, pkgsiteURL string) (resu } if known { u := pkgsiteURL + "/" + mp - return &triageResult{ - modulePath: mp, - reason: fmt.Sprintf("Reference data URL %q contains path %q; %q returned a status 200", r.URL, mp, u), + return &TriageResult{ + ModulePath: mp, + Reason: fmt.Sprintf("Reference data URL %q contains path %q; %q returned a status 200", r.URL, mp, u), }, nil } } @@ -151,9 +147,9 @@ func triageV4CVE(ctx context.Context, c *cveschema.CVE, pkgsiteURL string) (resu // Example CVE containing snyk.io URL: // https://github.com/CVEProject/cvelist/blob/899bba20d62eb73e04d1841a5ff04cd6225e1618/2020/7xxx/CVE-2020-7668.json#L52. if strings.Contains(r.URL, snykIdentifier) { - return &triageResult{ - modulePath: unknownPath, - reason: fmt.Sprintf("Reference data URL %q contains %q", r.URL, snykIdentifier), + return &TriageResult{ + ModulePath: unknownPath, + Reason: fmt.Sprintf("Reference data URL %q contains %q", r.URL, snykIdentifier), }, nil } @@ -161,9 +157,9 @@ func triageV4CVE(ctx context.Context, c *cveschema.CVE, pkgsiteURL string) (resu // project. for _, k := range stdlibReferenceDataKeywords { if strings.Contains(r.URL, k) { - return &triageResult{ - modulePath: stdlib.ModulePath, - reason: fmt.Sprintf("Reference data URL %q contains %q", r.URL, k), + return &TriageResult{ + ModulePath: stdlib.ModulePath, + Reason: fmt.Sprintf("Reference data URL %q contains %q", r.URL, k), }, nil } } @@ -173,69 +169,10 @@ func triageV4CVE(ctx context.Context, c *cveschema.CVE, pkgsiteURL string) (resu var ghsaRegex = regexp.MustCompile(ghsa.Regex) -func getAliasGHSAs(c *cveschema.CVE) []string { +func GetAliasGHSAs(c *cveschema.CVE) []string { var ghsas []string for _, r := range c.References.Data { ghsas = append(ghsas, ghsaRegex.FindAllString(r.URL, 1)...) } return ghsas } - -// Limit pkgsite requests to this many per second. -const pkgsiteQPS = 5 - -var ( - // The limiter used to throttle pkgsite requests. - // The second argument to rate.NewLimiter is the burst, which - // basically lets you exceed the rate briefly. - pkgsiteRateLimiter = rate.NewLimiter(rate.Every(time.Duration(1000/float64(pkgsiteQPS))*time.Millisecond), 3) - - // Cache of module paths already seen. - seenModulePath = map[string]bool{} - // Does seenModulePath contain all known modules? - cacheComplete = false -) - -// SetKnownModules provides a list of all known modules, -// so that no requests need to be made to pkg.go.dev. -func SetKnownModules(mods []string) { - for _, m := range mods { - seenModulePath[m] = true - } - cacheComplete = true -} - -// knownToPkgsite reports whether pkgsite knows that modulePath actually refers -// to a module. -func knownToPkgsite(ctx context.Context, baseURL, modulePath string) (bool, error) { - // If we've seen it before, no need to call. - if b, ok := seenModulePath[modulePath]; ok { - return b, nil - } - if cacheComplete { - return false, nil - } - // Pause to maintain a max QPS. - if err := pkgsiteRateLimiter.Wait(ctx); err != nil { - return false, err - } - start := time.Now() - - url := baseURL + "/mod/" + modulePath - res, err := http.Head(url) - var status string - if err == nil { - status = strconv.Quote(res.Status) - } - log.With( - "latency", time.Since(start), - "status", status, - "error", err, - ).Debugf(ctx, "checked if %s is known to pkgsite at HEAD", url) - if err != nil { - return false, err - } - known := res.StatusCode == http.StatusOK - seenModulePath[modulePath] = known - return known, nil -} diff --git a/internal/worker/triage_test.go b/internal/cveutils/triage_test.go similarity index 73% rename from internal/worker/triage_test.go rename to internal/cveutils/triage_test.go index 52e10cb2..d8da5c58 100644 --- a/internal/worker/triage_test.go +++ b/internal/cveutils/triage_test.go @@ -5,14 +5,11 @@ //go:build go1.17 // +build go1.17 -package worker +package cveutils import ( "context" "flag" - "net/http" - "net/http/httptest" - "strings" "testing" "github.com/google/go-cmp/cmp" @@ -25,12 +22,12 @@ var usePkgsite = flag.Bool("pkgsite", false, "use pkg.go.dev for tests") func TestTriageV4CVE(t *testing.T) { ctx := context.Background() - url := getPkgsiteURL(t) + url := GetPkgsiteURL(t, *usePkgsite) for _, test := range []struct { name string in *cveschema.CVE - want *triageResult + want *TriageResult }{ { "repo path is unknown Go standard library", @@ -41,8 +38,8 @@ func TestTriageV4CVE(t *testing.T) { }, }, }, - &triageResult{ - modulePath: stdlib.ModulePath, + &TriageResult{ + ModulePath: stdlib.ModulePath, }, }, { @@ -54,9 +51,9 @@ func TestTriageV4CVE(t *testing.T) { }, }, }, - &triageResult{ - modulePath: stdlib.ModulePath, - packagePath: "net/http", + &TriageResult{ + ModulePath: stdlib.ModulePath, + PackagePath: "net/http", }, }, { @@ -69,8 +66,8 @@ func TestTriageV4CVE(t *testing.T) { }, }, }, - &triageResult{ - modulePath: "golang.org/x/mod", + &TriageResult{ + ModulePath: "golang.org/x/mod", }, }, { @@ -82,8 +79,8 @@ func TestTriageV4CVE(t *testing.T) { }, }, }, - &triageResult{ - modulePath: "golang.org/x/mod", + &TriageResult{ + ModulePath: "golang.org/x/mod", }, }, { @@ -95,9 +92,9 @@ func TestTriageV4CVE(t *testing.T) { }, }, }, - &triageResult{ - modulePath: stdlib.ModulePath, - packagePath: "net/http", + &TriageResult{ + ModulePath: stdlib.ModulePath, + PackagePath: "net/http", }, }, { @@ -120,8 +117,8 @@ func TestTriageV4CVE(t *testing.T) { }, }, }, - &triageResult{ - modulePath: "golang.org/x/exp/event", + &TriageResult{ + ModulePath: "golang.org/x/exp/event", }, }, { @@ -144,8 +141,8 @@ func TestTriageV4CVE(t *testing.T) { }, }, }, - &triageResult{ - modulePath: unknownPath, + &TriageResult{ + ModulePath: unknownPath, }, }, } { @@ -156,8 +153,8 @@ func TestTriageV4CVE(t *testing.T) { t.Fatal(err) } if diff := cmp.Diff(test.want, got, - cmp.AllowUnexported(triageResult{}), - cmpopts.IgnoreFields(triageResult{}, "reason")); diff != "" { + cmp.AllowUnexported(TriageResult{}), + cmpopts.IgnoreFields(TriageResult{}, "Reason")); diff != "" { t.Errorf("mismatch (-want, +got):\n%s", diff) } }) @@ -168,7 +165,7 @@ func TestKnownToPkgsite(t *testing.T) { ctx := context.Background() const validModule = "golang.org/x/mod" - url := getPkgsiteURL(t) + url := GetPkgsiteURL(t, *usePkgsite) for _, test := range []struct { in string @@ -199,25 +196,7 @@ func TestGetAliasGHSAs(t *testing.T) { }, } want := "GHSA-xxxx-yyyy-0000" - if got := getAliasGHSAs(cve); got[0] != want { + if got := GetAliasGHSAs(cve); got[0] != want { t.Errorf("getAliasGHSAs: got %s, want %s", got, want) } } - -// getPkgsiteURL returns a URL to either a fake server or the real pkg.go.dev, -// depending on the usePkgsite flag. -func getPkgsiteURL(t *testing.T) string { - if *usePkgsite { - return pkgsiteURL - } - // Start a test server that recognizes anything from golang.org and bitbucket.org/foo/bar/baz. - s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - modulePath := strings.TrimPrefix(r.URL.Path, "/mod/") - if !strings.HasPrefix(modulePath, "golang.org/") && - !strings.HasPrefix(modulePath, "bitbucket.org/foo/bar/baz") { - http.Error(w, "unknown", http.StatusNotFound) - } - })) - t.Cleanup(s.Close) - return s.URL -} diff --git a/internal/worker/update.go b/internal/worker/update.go index 9cbe4d5a..05467fe0 100644 --- a/internal/worker/update.go +++ b/internal/worker/update.go @@ -16,6 +16,7 @@ import ( "github.com/go-git/go-git/v5/plumbing/object" "golang.org/x/vulndb/internal/cvelistrepo" "golang.org/x/vulndb/internal/cveschema" + "golang.org/x/vulndb/internal/cveutils" "golang.org/x/vulndb/internal/derrors" "golang.org/x/vulndb/internal/ghsa" "golang.org/x/vulndb/internal/observe" @@ -26,7 +27,7 @@ import ( // A triageFunc triages a CVE: it decides whether an issue needs to be filed. // If so, it returns a non-empty string indicating the possibly // affected module. -type triageFunc func(*cveschema.CVE) (*triageResult, error) +type triageFunc func(*cveschema.CVE) (*cveutils.TriageResult, error) // A cveUpdater performs an update operation on the DB. type cveUpdater struct { @@ -260,7 +261,7 @@ func (u *cveUpdater) updateBatch(ctx context.Context, batch []cvelistrepo.File) // worker has already handled, and returns the appropriate triage state // based on this. func checkForAliases(cve *cveschema.CVE, tx store.Transaction) (store.TriageState, error) { - for _, ghsaID := range getAliasGHSAs(cve) { + for _, ghsaID := range cveutils.GetAliasGHSAs(cve) { ghsa, err := tx.GetGHSARecord(ghsaID) if err != nil { return "", err @@ -282,7 +283,7 @@ func (u *cveUpdater) handleCVE(f cvelistrepo.File, old *store.CVERecord, tx stor if err := cvelistrepo.Parse(u.repo, f, cve); err != nil { return nil, false, err } - var result *triageResult + var result *cveutils.TriageResult if cve.State == cveschema.StatePublic && !u.knownIDs[cve.ID] { c := cve // If a false positive has changed, we only care about @@ -309,9 +310,9 @@ func (u *cveUpdater) handleCVE(f cvelistrepo.File, old *store.CVERecord, tx stor return nil, false, err } cr.TriageState = triageState - cr.Module = result.modulePath - cr.Package = result.packagePath - cr.TriageStateReason = result.reason + cr.Module = result.ModulePath + cr.Package = result.PackagePath + cr.TriageStateReason = result.Reason cr.CVE = cve case u.knownIDs[cve.ID]: cr.TriageState = store.TriageStateHasVuln @@ -332,9 +333,9 @@ func (u *cveUpdater) handleCVE(f cvelistrepo.File, old *store.CVERecord, tx stor if result != nil { // Didn't need an issue before, does now. mod.TriageState = store.TriageStateNeedsIssue - mod.Module = result.modulePath - mod.Package = result.packagePath - mod.TriageStateReason = result.reason + mod.Module = result.ModulePath + mod.Package = result.PackagePath + mod.TriageStateReason = result.Reason mod.CVE = cve } // Else don't change the triage state, but we still want @@ -355,7 +356,7 @@ func (u *cveUpdater) handleCVE(f cvelistrepo.File, old *store.CVERecord, tx stor mod.TriageState = store.TriageStateUpdatedSinceIssueCreation var mp string if result != nil { - mp = result.modulePath + mp = result.ModulePath } mod.TriageStateReason = fmt.Sprintf("CVE changed; affected module = %q", mp) case store.TriageStateAlias: diff --git a/internal/worker/update_test.go b/internal/worker/update_test.go index 9490ef99..63d0f4bf 100644 --- a/internal/worker/update_test.go +++ b/internal/worker/update_test.go @@ -9,6 +9,7 @@ package worker import ( "context" + "flag" "testing" "time" @@ -18,11 +19,14 @@ import ( "github.com/google/go-cmp/cmp/cmpopts" "golang.org/x/vulndb/internal/cvelistrepo" "golang.org/x/vulndb/internal/cveschema" + "golang.org/x/vulndb/internal/cveutils" "golang.org/x/vulndb/internal/ghsa" "golang.org/x/vulndb/internal/gitrepo" "golang.org/x/vulndb/internal/worker/store" ) +var usePkgsite = flag.Bool("pkgsite", false, "use pkg.go.dev for tests") + const clearString = "**CLEAR**" var clearCVE = &cveschema.CVE{} @@ -90,9 +94,9 @@ func TestDoUpdate(t *testing.T) { t.Fatal(err) } commit := headCommit(t, repo) - purl := getPkgsiteURL(t) - needsIssue := func(cve *cveschema.CVE) (*triageResult, error) { - return TriageCVE(ctx, cve, purl) + purl := cveutils.GetPkgsiteURL(t, *usePkgsite) + needsIssue := func(cve *cveschema.CVE) (*cveutils.TriageResult, error) { + return cveutils.TriageCVE(ctx, cve, purl) } commitHash := commit.Hash.String() diff --git a/internal/worker/worker.go b/internal/worker/worker.go index c79d8dca..202d5d09 100644 --- a/internal/worker/worker.go +++ b/internal/worker/worker.go @@ -22,6 +22,7 @@ import ( "golang.org/x/time/rate" "golang.org/x/vulndb/internal/cvelistrepo" "golang.org/x/vulndb/internal/cveschema" + "golang.org/x/vulndb/internal/cveutils" "golang.org/x/vulndb/internal/derrors" "golang.org/x/vulndb/internal/ghsa" "golang.org/x/vulndb/internal/gitrepo" @@ -74,8 +75,8 @@ func UpdateCVEsAtCommit(ctx context.Context, repoPath, commitHashString string, if err != nil { return err } - u := newCVEUpdater(repo, commit, st, knownVulnIDs, func(cve *cveschema.CVE) (*triageResult, error) { - return TriageCVE(ctx, cve, pkgsiteURL) + u := newCVEUpdater(repo, commit, st, knownVulnIDs, func(cve *cveschema.CVE) (*cveutils.TriageResult, error) { + return cveutils.TriageCVE(ctx, cve, pkgsiteURL) }) _, err = u.update(ctx) return err