diff --git a/internal/scan/binary.go b/internal/scan/binary.go index cfee2c7..37c6601 100644 --- a/internal/scan/binary.go +++ b/internal/scan/binary.go @@ -9,20 +9,18 @@ package scan import ( "context" - "fmt" "os" - "strings" - "unicode" "golang.org/x/vuln/internal/client" + "golang.org/x/vuln/internal/derrors" "golang.org/x/vuln/internal/govulncheck" - "golang.org/x/vuln/internal/osv" "golang.org/x/vuln/internal/vulncheck" ) // runBinary detects presence of vulnerable symbols in an executable. -func runBinary(ctx context.Context, handler govulncheck.Handler, cfg *config, client *client.Client) error { - var exe *os.File +func runBinary(ctx context.Context, handler govulncheck.Handler, cfg *config, client *client.Client) (err error) { + defer derrors.Wrap(&err, "govulncheck") + exe, err := os.Open(cfg.patterns[0]) if err != nil { return err @@ -33,110 +31,5 @@ func runBinary(ctx context.Context, handler govulncheck.Handler, cfg *config, cl if err := handler.Progress(p); err != nil { return err } - vr, err := vulncheck.Binary(ctx, exe, &cfg.Config, client) - if err != nil { - return fmt.Errorf("govulncheck: %v", err) - } - callstacks := binaryCallstacks(vr) - return emitBinaryResult(handler, vr, callstacks) -} - -func emitBinaryResult(handler govulncheck.Handler, vr *vulncheck.Result, callstacks map[*vulncheck.Vuln]vulncheck.CallStack) error { - osvs := map[string]*osv.Entry{} - // first deal with all the affected vulnerabilities - emitted := map[string]bool{} - seen := map[string]bool{} - emitFinding := func(finding *govulncheck.Finding) error { - if !seen[finding.OSV] { - seen[finding.OSV] = true - if err := handler.OSV(osvs[finding.OSV]); err != nil { - return err - } - } - return handler.Finding(finding) - } - - for _, vv := range vr.Vulns { - osvs[vv.OSV.ID] = vv.OSV - fixed := vulncheck.FixedVersion(vulncheck.ModPath(vv.ImportSink.Module), vulncheck.ModVersion(vv.ImportSink.Module), vv.OSV.Affected) - stack := callstacks[vv] - if stack == nil { - continue - } - emitted[vv.OSV.ID] = true - emitFinding(&govulncheck.Finding{ - OSV: vv.OSV.ID, - FixedVersion: fixed, - Trace: tracefromEntries(stack), - }) - } - for _, vv := range vr.Vulns { - if emitted[vv.OSV.ID] { - continue - } - stacks := callstacks[vv] - if len(stacks) != 0 { - continue - } - emitted[vv.OSV.ID] = true - emitFinding(&govulncheck.Finding{ - OSV: vv.OSV.ID, - FixedVersion: vulncheck.FixedVersion(vulncheck.ModPath(vv.ImportSink.Module), vulncheck.ModVersion(vv.ImportSink.Module), vv.OSV.Affected), - Trace: []*govulncheck.Frame{frameFromPackage(vv.ImportSink)}, - }) - } - return nil -} - -func binaryCallstacks(vr *vulncheck.Result) map[*vulncheck.Vuln]vulncheck.CallStack { - callstacks := map[*vulncheck.Vuln]vulncheck.CallStack{} - for _, vv := range uniqueVulns(vr.Vulns) { - f := &vulncheck.FuncNode{Package: vv.ImportSink, Name: vv.Symbol} - parts := strings.Split(vv.Symbol, ".") - if len(parts) != 1 { - f.RecvType = parts[0] - f.Name = parts[1] - } - callstacks[vv] = vulncheck.CallStack{vulncheck.StackEntry{Function: f}} - } - return callstacks -} - -// uniqueVulns does for binary mode what uniqueCallStack does for source mode. -// It tries not to report redundant symbols. Since there are no call stacks in -// binary mode, the following approximate approach is used. Do not report unexported -// symbols for a triple if there are some exported symbols. -// Otherwise, report all unexported symbols to avoid not reporting anything. -func uniqueVulns(vulns []*vulncheck.Vuln) []*vulncheck.Vuln { - type key struct { - id string - pkg string - mod string - } - hasExported := make(map[key]bool) - for _, v := range vulns { - if isExported(v.Symbol) { - k := key{id: v.OSV.ID, pkg: v.ImportSink.PkgPath, mod: v.ImportSink.Module.Path} - hasExported[k] = true - } - } - - var uniques []*vulncheck.Vuln - for _, v := range vulns { - k := key{id: v.OSV.ID, pkg: v.ImportSink.PkgPath, mod: v.ImportSink.Module.Path} - if isExported(v.Symbol) || !hasExported[k] { - uniques = append(uniques, v) - } - } - return uniques -} - -// isExported checks if the symbol is exported. Assumes that the -// symbol is of the form "identifier" or "identifier1.identifier2". -func isExported(symbol string) bool { - parts := strings.Split(symbol, ".") - if len(parts) == 1 { - return unicode.IsUpper(rune(symbol[0])) - } - return unicode.IsUpper(rune(parts[1][0])) + return vulncheck.Binary(ctx, handler, exe, &cfg.Config, client) } diff --git a/internal/scan/binary_test.go b/internal/scan/binary_test.go deleted file mode 100644 index f451453..0000000 --- a/internal/scan/binary_test.go +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright 2023 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package scan - -import ( - "testing" -) - -func TestIsExported(t *testing.T) { - for _, tc := range []struct { - symbol string - want bool - }{ - {"foo", false}, - {"Foo", true}, - {"x.foo", false}, - {"X.foo", false}, - {"x.Foo", true}, - {"X.Foo", true}, - } { - tc := tc - t.Run(tc.symbol, func(t *testing.T) { - if got := isExported(tc.symbol); tc.want != got { - t.Errorf("want %t; got %t", tc.want, got) - } - }) - } -} diff --git a/internal/scan/source.go b/internal/scan/source.go index c6d7eeb..d7ce45c 100644 --- a/internal/scan/source.go +++ b/internal/scan/source.go @@ -8,10 +8,10 @@ import ( "context" "fmt" "path/filepath" - "sort" "golang.org/x/tools/go/packages" "golang.org/x/vuln/internal/client" + "golang.org/x/vuln/internal/derrors" "golang.org/x/vuln/internal/govulncheck" "golang.org/x/vuln/internal/vulncheck" ) @@ -21,7 +21,9 @@ import ( // Vulnerabilities can be called (affecting the package, because a vulnerable // symbol is actually exercised) or just imported by the package // (likely having a non-affecting outcome). -func runSource(ctx context.Context, handler govulncheck.Handler, cfg *config, client *client.Client, dir string) error { +func runSource(ctx context.Context, handler govulncheck.Handler, cfg *config, client *client.Client, dir string) (err error) { + defer derrors.Wrap(&err, "govulncheck") + if len(cfg.patterns) == 0 { return nil } @@ -49,77 +51,7 @@ func runSource(ctx context.Context, handler govulncheck.Handler, cfg *config, cl if len(pkgs) == 0 { return nil // early exit } - vr, err := vulncheck.Source(ctx, handler, pkgs, &cfg.Config, client, graph) - if err != nil { - return err - } - callStacks := vulncheck.CallStacks(vr) - return emitCalledVulns(handler, callStacks) -} - -func emitCalledVulns(handler govulncheck.Handler, callstacks map[*vulncheck.Vuln]vulncheck.CallStack) error { - var vulns []*vulncheck.Vuln - - for v := range callstacks { - vulns = append(vulns, v) - } - - sort.SliceStable(vulns, func(i, j int) bool { - return vulns[i].Symbol < vulns[j].Symbol - }) - - for _, vuln := range vulns { - stack := callstacks[vuln] - if stack == nil { - continue - } - fixed := vulncheck.FixedVersion(vulncheck.ModPath(vuln.ImportSink.Module), vulncheck.ModVersion(vuln.ImportSink.Module), vuln.OSV.Affected) - handler.Finding(&govulncheck.Finding{ - OSV: vuln.OSV.ID, - FixedVersion: fixed, - Trace: tracefromEntries(stack), - }) - } - return nil -} - -// tracefromEntries creates a sequence of -// frames from vcs. Position of a Frame is the -// call position of the corresponding stack entry. -func tracefromEntries(vcs vulncheck.CallStack) []*govulncheck.Frame { - var frames []*govulncheck.Frame - for i := len(vcs) - 1; i >= 0; i-- { - e := vcs[i] - fr := frameFromPackage(e.Function.Package) - fr.Function = e.Function.Name - fr.Receiver = e.Function.Receiver() - if e.Call == nil || e.Call.Pos == nil { - fr.Position = nil - } else { - fr.Position = &govulncheck.Position{ - Filename: e.Call.Pos.Filename, - Offset: e.Call.Pos.Offset, - Line: e.Call.Pos.Line, - Column: e.Call.Pos.Column, - } - } - frames = append(frames, fr) - } - return frames -} - -func frameFromPackage(pkg *packages.Package) *govulncheck.Frame { - fr := &govulncheck.Frame{} - if pkg != nil { - fr.Module = pkg.Module.Path - fr.Version = pkg.Module.Version - fr.Package = pkg.PkgPath - } - if pkg.Module.Replace != nil { - fr.Module = pkg.Module.Replace.Path - fr.Version = pkg.Module.Replace.Version - } - return fr + return vulncheck.Source(ctx, handler, pkgs, &cfg.Config, client, graph) } // sourceProgressMessage returns a string of the form @@ -199,15 +131,15 @@ func depPkgs(topPkgs []*packages.Package) int { // and actionable error message to surface for the end user. func parseLoadError(err error, dir string, pkgs bool) error { if !fileExists(filepath.Join(dir, "go.mod")) { - return fmt.Errorf("govulncheck: %v", errNoGoMod) + return errNoGoMod } if isGoVersionMismatchError(err) { - return fmt.Errorf("govulncheck: %v\n\n%v", errGoVersionMismatch, err) + return fmt.Errorf("%v\n\n%v", errGoVersionMismatch, err) } level := "modules" if pkgs { level = "packages" } - return fmt.Errorf("govulncheck: loading %s: %w", level, err) + return fmt.Errorf("loading %s: %w", level, err) } diff --git a/internal/vulncheck/binary.go b/internal/vulncheck/binary.go index a35a3a0..741726e 100644 --- a/internal/vulncheck/binary.go +++ b/internal/vulncheck/binary.go @@ -20,9 +20,20 @@ import ( "golang.org/x/vuln/internal/vulncheck/internal/buildinfo" ) -// Binary detects presence of vulnerable symbols in exe. +// Binary detects presence of vulnerable symbols in exe and +// emits findings to exe. +func Binary(ctx context.Context, handler govulncheck.Handler, exe io.ReaderAt, cfg *govulncheck.Config, client *client.Client) error { + vr, err := binary(ctx, exe, cfg, client) + if err != nil { + return err + } + callstacks := binaryCallstacks(vr) + return emitBinaryResult(handler, vr, callstacks) +} + +// binary detects presence of vulnerable symbols in exe. // The Calls, Imports, and Requires fields on Result will be empty. -func Binary(ctx context.Context, exe io.ReaderAt, cfg *govulncheck.Config, client *client.Client) (_ *Result, err error) { +func binary(ctx context.Context, exe io.ReaderAt, cfg *govulncheck.Config, client *client.Client) (_ *Result, err error) { mods, packageSymbols, bi, err := buildinfo.ExtractPackagesAndSymbols(exe) if err != nil { return nil, fmt.Errorf("could not parse provided binary: %v", err) diff --git a/internal/vulncheck/binary_test.go b/internal/vulncheck/binary_test.go index 7b1abfc..0494ea2 100644 --- a/internal/vulncheck/binary_test.go +++ b/internal/vulncheck/binary_test.go @@ -119,7 +119,7 @@ func TestBinary(t *testing.T) { // Test imports only mode cfg := &govulncheck.Config{ScanLevel: "package"} - res, err := Binary(context.Background(), bin, cfg, c) + res, err := binary(context.Background(), bin, cfg, c) if err != nil { t.Fatal(err) } @@ -143,7 +143,7 @@ func TestBinary(t *testing.T) { // Test the symbols (non-import mode) cfg.ScanLevel = "symbol" - res, err = Binary(context.Background(), bin, cfg, c) + res, err = binary(context.Background(), bin, cfg, c) if err != nil { t.Fatal(err) } @@ -237,7 +237,7 @@ func Vuln() { } cfg := &govulncheck.Config{ScanLevel: "symbol"} - res, err := Binary(context.Background(), bin, cfg, c) + res, err := binary(context.Background(), bin, cfg, c) if err != nil { t.Fatal(err) } diff --git a/internal/vulncheck/emit.go b/internal/vulncheck/emit.go new file mode 100644 index 0000000..b033bcc --- /dev/null +++ b/internal/vulncheck/emit.go @@ -0,0 +1,174 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package vulncheck + +import ( + "sort" + + "golang.org/x/tools/go/packages" + "golang.org/x/vuln/internal" + "golang.org/x/vuln/internal/govulncheck" + "golang.org/x/vuln/internal/osv" +) + +func emitCalledVulns(handler govulncheck.Handler, callstacks map[*Vuln]CallStack) error { + var vulns []*Vuln + + for v := range callstacks { + vulns = append(vulns, v) + } + + sort.SliceStable(vulns, func(i, j int) bool { + return vulns[i].Symbol < vulns[j].Symbol + }) + + for _, vuln := range vulns { + stack := callstacks[vuln] + if stack == nil { + continue + } + fixed := FixedVersion(modPath(vuln.ImportSink.Module), modVersion(vuln.ImportSink.Module), vuln.OSV.Affected) + handler.Finding(&govulncheck.Finding{ + OSV: vuln.OSV.ID, + FixedVersion: fixed, + Trace: tracefromEntries(stack), + }) + } + return nil +} + +func emitModuleFindings(handler govulncheck.Handler, modVulns moduleVulnerabilities) map[string]*osv.Entry { + osvs := make(map[string]*osv.Entry) + for _, vuln := range modVulns { + for _, osv := range vuln.Vulns { + if _, found := osvs[osv.ID]; !found { + handler.OSV(osv) + } + handler.Finding(&govulncheck.Finding{ + OSV: osv.ID, + FixedVersion: FixedVersion(modPath(vuln.Module), modVersion(vuln.Module), osv.Affected), + Trace: []*govulncheck.Frame{frameFromModule(vuln.Module, osv.Affected)}, + }) + } + } + return osvs +} + +func emitPackageFinding(handler govulncheck.Handler, vuln *Vuln) error { + return handler.Finding(&govulncheck.Finding{ + OSV: vuln.OSV.ID, + FixedVersion: FixedVersion(modPath(vuln.ImportSink.Module), modVersion(vuln.ImportSink.Module), vuln.OSV.Affected), + Trace: []*govulncheck.Frame{frameFromPackage(vuln.ImportSink)}, + }) +} + +// tracefromEntries creates a sequence of +// frames from vcs. Position of a Frame is the +// call position of the corresponding stack entry. +func tracefromEntries(vcs CallStack) []*govulncheck.Frame { + var frames []*govulncheck.Frame + for i := len(vcs) - 1; i >= 0; i-- { + e := vcs[i] + fr := frameFromPackage(e.Function.Package) + fr.Function = e.Function.Name + fr.Receiver = e.Function.Receiver() + if e.Call == nil || e.Call.Pos == nil { + fr.Position = nil + } else { + fr.Position = &govulncheck.Position{ + Filename: e.Call.Pos.Filename, + Offset: e.Call.Pos.Offset, + Line: e.Call.Pos.Line, + Column: e.Call.Pos.Column, + } + } + frames = append(frames, fr) + } + return frames +} + +func frameFromPackage(pkg *packages.Package) *govulncheck.Frame { + fr := &govulncheck.Frame{} + if pkg != nil { + fr.Module = pkg.Module.Path + fr.Version = pkg.Module.Version + fr.Package = pkg.PkgPath + } + if pkg.Module.Replace != nil { + fr.Module = pkg.Module.Replace.Path + fr.Version = pkg.Module.Replace.Version + } + return fr +} + +func frameFromModule(mod *packages.Module, affected []osv.Affected) *govulncheck.Frame { + fr := &govulncheck.Frame{ + Module: mod.Path, + Version: mod.Version, + } + + if mod.Path == internal.GoStdModulePath { + for _, a := range affected { + if a.Module.Path != mod.Path { + continue + } + fr.Package = a.EcosystemSpecific.Packages[0].Path + } + } + + if mod.Replace != nil { + fr.Module = mod.Replace.Path + fr.Version = mod.Replace.Version + } + + return fr +} + +func emitBinaryResult(handler govulncheck.Handler, vr *Result, callstacks map[*Vuln]CallStack) error { + osvs := map[string]*osv.Entry{} + // first deal with all the affected vulnerabilities + emitted := map[string]bool{} + seen := map[string]bool{} + emitFinding := func(finding *govulncheck.Finding) error { + if !seen[finding.OSV] { + seen[finding.OSV] = true + if err := handler.OSV(osvs[finding.OSV]); err != nil { + return err + } + } + return handler.Finding(finding) + } + + for _, vv := range vr.Vulns { + osvs[vv.OSV.ID] = vv.OSV + fixed := FixedVersion(modPath(vv.ImportSink.Module), modVersion(vv.ImportSink.Module), vv.OSV.Affected) + stack := callstacks[vv] + if stack == nil { + continue + } + emitted[vv.OSV.ID] = true + emitFinding(&govulncheck.Finding{ + OSV: vv.OSV.ID, + FixedVersion: fixed, + Trace: tracefromEntries(stack), + }) + } + for _, vv := range vr.Vulns { + if emitted[vv.OSV.ID] { + continue + } + stacks := callstacks[vv] + if len(stacks) != 0 { + continue + } + emitted[vv.OSV.ID] = true + emitFinding(&govulncheck.Finding{ + OSV: vv.OSV.ID, + FixedVersion: FixedVersion(modPath(vv.ImportSink.Module), modVersion(vv.ImportSink.Module), vv.OSV.Affected), + Trace: []*govulncheck.Frame{frameFromPackage(vv.ImportSink)}, + }) + } + return nil +} diff --git a/internal/vulncheck/findings.go b/internal/vulncheck/findings.go deleted file mode 100644 index 83c6090..0000000 --- a/internal/vulncheck/findings.go +++ /dev/null @@ -1,74 +0,0 @@ -// Copyright 2023 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package vulncheck - -import ( - "golang.org/x/tools/go/packages" - "golang.org/x/vuln/internal" - "golang.org/x/vuln/internal/govulncheck" - "golang.org/x/vuln/internal/osv" -) - -func frameFromPackage(pkg *packages.Package) *govulncheck.Frame { - fr := &govulncheck.Frame{ - Module: pkg.Module.Path, - Version: pkg.Module.Version, - Package: pkg.PkgPath, - } - - if pkg.Module.Replace != nil { - fr.Module = pkg.Module.Replace.Path - fr.Version = pkg.Module.Replace.Version - } - return fr -} - -func frameFromModule(mod *packages.Module, affected []osv.Affected) *govulncheck.Frame { - fr := &govulncheck.Frame{ - Module: mod.Path, - Version: mod.Version, - } - - if mod.Path == internal.GoStdModulePath { - for _, a := range affected { - if a.Module.Path != mod.Path { - continue - } - fr.Package = a.EcosystemSpecific.Packages[0].Path - } - } - - if mod.Replace != nil { - fr.Module = mod.Replace.Path - fr.Version = mod.Replace.Version - } - - return fr -} - -func emitModuleFindings(modVulns moduleVulnerabilities, handler govulncheck.Handler) map[string]*osv.Entry { - osvs := make(map[string]*osv.Entry) - for _, vuln := range modVulns { - for _, osv := range vuln.Vulns { - if _, found := osvs[osv.ID]; !found { - handler.OSV(osv) - } - handler.Finding(&govulncheck.Finding{ - OSV: osv.ID, - FixedVersion: FixedVersion(ModPath(vuln.Module), ModVersion(vuln.Module), osv.Affected), - Trace: []*govulncheck.Frame{frameFromModule(vuln.Module, osv.Affected)}, - }) - } - } - return osvs -} - -func emitPackageFinding(vuln *Vuln, handler govulncheck.Handler) error { - return handler.Finding(&govulncheck.Finding{ - OSV: vuln.OSV.ID, - FixedVersion: FixedVersion(ModPath(vuln.ImportSink.Module), ModVersion(vuln.ImportSink.Module), vuln.OSV.Affected), - Trace: []*govulncheck.Frame{frameFromPackage(vuln.ImportSink)}, - }) -} diff --git a/internal/vulncheck/source.go b/internal/vulncheck/source.go index 41b4223..124d704 100644 --- a/internal/vulncheck/source.go +++ b/internal/vulncheck/source.go @@ -16,7 +16,17 @@ import ( "golang.org/x/vuln/internal/osv" ) -// Source detects vulnerabilities in packages. The result will contain: +// Source detects vulnerabilities in pkgs and emits the findings to handler. +func Source(ctx context.Context, handler govulncheck.Handler, pkgs []*packages.Package, cfg *govulncheck.Config, client *client.Client, graph *PackageGraph) error { + vr, err := source(ctx, handler, pkgs, cfg, client, graph) + if err != nil { + return err + } + callStacks := sourceCallstacks(vr) + return emitCalledVulns(handler, callStacks) +} + +// source detects vulnerabilities in packages. The result will contain: // // 1) An ImportGraph related to an import of a package with some known // vulnerabilities. @@ -27,7 +37,7 @@ import ( // 3) A CallGraph leading to the use of a known vulnerable function or method. // // Assumes that pkgs are non-empty and belong to the same program. -func Source(ctx context.Context, handler govulncheck.Handler, pkgs []*packages.Package, cfg *govulncheck.Config, client *client.Client, graph *PackageGraph) (_ *Result, err error) { +func source(ctx context.Context, handler govulncheck.Handler, pkgs []*packages.Package, cfg *govulncheck.Config, client *client.Client, graph *PackageGraph) (_ *Result, err error) { ctx, cancel := context.WithCancel(ctx) defer cancel() @@ -63,7 +73,7 @@ func Source(ctx context.Context, handler govulncheck.Handler, pkgs []*packages.P modVulns = modVulns.filter("", "") result := &Result{} // instead of add to result, output using the handler - emitModuleFindings(modVulns, handler) + emitModuleFindings(handler, modVulns) if !cfg.ScanLevel.WantPackages() || len(modVulns) == 0 { return result, nil @@ -143,10 +153,10 @@ func vulnImportSlice(pkg *packages.Package, modVulns moduleVulnerabilities, resu if len(symbols) == 0 { symbols = allSymbols(pkg.Types) } - emitPackageFinding(&Vuln{ + emitPackageFinding(handler, &Vuln{ OSV: osv, ImportSink: pkg, - }, handler) + }) for _, symbol := range symbols { vuln := &Vuln{ OSV: osv, diff --git a/internal/vulncheck/source_test.go b/internal/vulncheck/source_test.go index e0f527c..8e5e41e 100644 --- a/internal/vulncheck/source_test.go +++ b/internal/vulncheck/source_test.go @@ -205,7 +205,7 @@ func TestCalls(t *testing.T) { } cfg := &govulncheck.Config{ScanLevel: "symbol"} - result, err := Source(context.Background(), test.NewMockHandler(), pkgs, cfg, c, graph) + result, err := source(context.Background(), test.NewMockHandler(), pkgs, cfg, c, graph) if err != nil { t.Fatal(err) } @@ -307,7 +307,7 @@ func TestAllSymbolsVulnerable(t *testing.T) { } cfg := &govulncheck.Config{ScanLevel: "symbol"} - result, err := Source(context.Background(), test.NewMockHandler(), pkgs, cfg, client, graph) + result, err := source(context.Background(), test.NewMockHandler(), pkgs, cfg, client, graph) if err != nil { t.Fatal(err) } @@ -377,7 +377,7 @@ func TestNoSyntheticNodes(t *testing.T) { } cfg := &govulncheck.Config{ScanLevel: "symbol"} - result, err := Source(context.Background(), test.NewMockHandler(), pkgs, cfg, c, graph) + result, err := source(context.Background(), test.NewMockHandler(), pkgs, cfg, c, graph) if err != nil { t.Fatal(err) } @@ -397,7 +397,7 @@ func TestNoSyntheticNodes(t *testing.T) { t.Fatal("VulnData.Vuln1 should be deemed a called vulnerability") } - stack := CallStacks(result)[vuln] + stack := sourceCallstacks(result)[vuln] // We don't want the call stack X -> *VulnData.Vuln1 (wrapper) -> VulnData.Vuln1. // We want X -> VulnData.Vuln1. if len(stack) != 2 { @@ -457,7 +457,7 @@ func TestRecursion(t *testing.T) { } cfg := &govulncheck.Config{ScanLevel: "symbol"} - result, err := Source(context.Background(), test.NewMockHandler(), pkgs, cfg, c, graph) + result, err := source(context.Background(), test.NewMockHandler(), pkgs, cfg, c, graph) if err != nil { t.Fatal(err) } @@ -522,7 +522,7 @@ func TestIssue57174(t *testing.T) { } cfg := &govulncheck.Config{ScanLevel: "symbol"} - _, err = Source(context.Background(), test.NewMockHandler(), pkgs, cfg, c, graph) + _, err = source(context.Background(), test.NewMockHandler(), pkgs, cfg, c, graph) if err != nil { t.Fatal(err) } diff --git a/internal/vulncheck/utils.go b/internal/vulncheck/utils.go index e976d88..2861091 100644 --- a/internal/vulncheck/utils.go +++ b/internal/vulncheck/utils.go @@ -335,14 +335,14 @@ func fixNegated(fix string, affected []osv.Affected) bool { return false } -func ModPath(mod *packages.Module) string { +func modPath(mod *packages.Module) string { if mod.Replace != nil { return mod.Replace.Path } return mod.Path } -func ModVersion(mod *packages.Module) string { +func modVersion(mod *packages.Module) string { if mod.Replace != nil { return mod.Replace.Version } diff --git a/internal/vulncheck/witness.go b/internal/vulncheck/witness.go index 7de3758..27f3597 100644 --- a/internal/vulncheck/witness.go +++ b/internal/vulncheck/witness.go @@ -13,6 +13,7 @@ import ( "strconv" "strings" "sync" + "unicode" "golang.org/x/tools/go/packages" ) @@ -32,18 +33,18 @@ type StackEntry struct { Call *CallSite } -// CallStacks returns representative call stacks for each +// sourceCallstacks returns representative call stacks for each // vulnerability in res. The returned call stacks are heuristically // ordered by how seemingly easy is to understand them: shorter // call stacks with less dynamic call sites appear earlier in the // returned slices. // -// CallStacks performs a breadth-first search of res.CallGraph starting -// at the vulnerable symbol and going up until reaching an entry +// sourceCallstacks performs a breadth-first search of res.CallGraph +// starting at the vulnerable symbol and going up until reaching an entry // function or method in res.CallGraph.Entries. During this search, // each function is visited at most once to avoid potential // exponential explosion. Hence, not all call stacks are analyzed. -func CallStacks(res *Result) map[*Vuln]CallStack { +func sourceCallstacks(res *Result) map[*Vuln]CallStack { var ( wg sync.WaitGroup mu sync.Mutex @@ -53,7 +54,7 @@ func CallStacks(res *Result) map[*Vuln]CallStack { vuln := vuln wg.Add(1) go func() { - cs := callStack(vuln, res) + cs := sourceCallstack(vuln, res) mu.Lock() stackPerVuln[vuln] = cs mu.Unlock() @@ -66,10 +67,10 @@ func CallStacks(res *Result) map[*Vuln]CallStack { return stackPerVuln } -// callStack finds a representative call stack for vuln. +// sourceCallstack finds a representative call stack for vuln. // This is a shortest unique call stack with the least // number of dynamic call sites. -func callStack(vuln *Vuln, res *Result) CallStack { +func sourceCallstack(vuln *Vuln, res *Result) CallStack { vulnSink := vuln.CallSink if vulnSink == nil { return nil @@ -390,3 +391,57 @@ func isInit(f *FuncNode) bool { // positive integer. Implicit inits are named simply "init". return f.Name == "init" || strings.HasPrefix(f.Name, "init#") } + +// binaryCallstacks computes representative call stacks for binary results. +func binaryCallstacks(vr *Result) map[*Vuln]CallStack { + callstacks := map[*Vuln]CallStack{} + for _, vv := range uniqueVulns(vr.Vulns) { + f := &FuncNode{Package: vv.ImportSink, Name: vv.Symbol} + parts := strings.Split(vv.Symbol, ".") + if len(parts) != 1 { + f.RecvType = parts[0] + f.Name = parts[1] + } + callstacks[vv] = CallStack{StackEntry{Function: f}} + } + return callstacks +} + +// uniqueVulns does for binary mode what sourceCallstacks does for source mode. +// It tries not to report redundant symbols. Since there are no call stacks in +// binary mode, the following approximate approach is used. Do not report unexported +// symbols for a triple if there are some exported symbols. +// Otherwise, report all unexported symbols to avoid not reporting anything. +func uniqueVulns(vulns []*Vuln) []*Vuln { + type key struct { + id string + pkg string + mod string + } + hasExported := make(map[key]bool) + for _, v := range vulns { + if isExported(v.Symbol) { + k := key{id: v.OSV.ID, pkg: v.ImportSink.PkgPath, mod: v.ImportSink.Module.Path} + hasExported[k] = true + } + } + + var uniques []*Vuln + for _, v := range vulns { + k := key{id: v.OSV.ID, pkg: v.ImportSink.PkgPath, mod: v.ImportSink.Module.Path} + if isExported(v.Symbol) || !hasExported[k] { + uniques = append(uniques, v) + } + } + return uniques +} + +// isExported checks if the symbol is exported. Assumes that the +// symbol is of the form "identifier" or "identifier1.identifier2". +func isExported(symbol string) bool { + parts := strings.Split(symbol, ".") + if len(parts) == 1 { + return unicode.IsUpper(rune(symbol[0])) + } + return unicode.IsUpper(rune(parts[1][0])) +} diff --git a/internal/vulncheck/witness_test.go b/internal/vulncheck/witness_test.go index dc770ad..82a8cc3 100644 --- a/internal/vulncheck/witness_test.go +++ b/internal/vulncheck/witness_test.go @@ -36,7 +36,7 @@ func stacksToString(stacks map[*Vuln]CallStack) map[string]string { return m } -func TestCallStacks(t *testing.T) { +func TestSourceCallstacks(t *testing.T) { // Call graph structure for the test program // entry1 entry2 // | | @@ -66,13 +66,13 @@ func TestCallStacks(t *testing.T) { "vuln2": "entry2->interm2->vuln2", } - stacks := CallStacks(res) + stacks := sourceCallstacks(res) if got := stacksToString(stacks); !reflect.DeepEqual(want, got) { t.Errorf("want %v; got %v", want, got) } } -func TestUniqueCallStack(t *testing.T) { +func TestSourceUniqueCallStack(t *testing.T) { // Call graph structure for the test program // entry1 entry2 // | | @@ -102,7 +102,7 @@ func TestUniqueCallStack(t *testing.T) { "vuln2": "entry2->interm1->interm2->vuln2", } - stacks := CallStacks(res) + stacks := sourceCallstacks(res) if got := stacksToString(stacks); !reflect.DeepEqual(want, got) { t.Errorf("want %v; got %v", want, got) } @@ -190,12 +190,12 @@ func TestInits(t *testing.T) { t.Fatal("failed to load x test package") } cfg := &govulncheck.Config{ScanLevel: "symbol"} - result, err := Source(context.Background(), test.NewMockHandler(), pkgs, cfg, testClient, graph) + result, err := source(context.Background(), test.NewMockHandler(), pkgs, cfg, testClient, graph) if err != nil { t.Fatal(err) } - cs := CallStacks(result) + cs := sourceCallstacks(result) want := map[string][]string{ "A": { // Entry init's position is the package statement. @@ -243,3 +243,24 @@ func fullStacksToString(callStacks map[*Vuln]CallStack) map[string][]string { } return m } + +func TestIsExported(t *testing.T) { + for _, tc := range []struct { + symbol string + want bool + }{ + {"foo", false}, + {"Foo", true}, + {"x.foo", false}, + {"X.foo", false}, + {"x.Foo", true}, + {"X.Foo", true}, + } { + tc := tc + t.Run(tc.symbol, func(t *testing.T) { + if got := isExported(tc.symbol); tc.want != got { + t.Errorf("want %t; got %t", tc.want, got) + } + }) + } +}