internal/worker: unify some firestore functions

Where possible, unify Firestore operations so that there
are not separate functions for adding/updating/fetching
a record based on its type (CVE/GHSA). The intended operation
is inferred by given ID(s).

Change-Id: Ic82e3ab4c9d519c3101f95444bc0ad306fa2a14e
Reviewed-on: https://go-review.googlesource.com/c/vulndb/+/588759
Reviewed-by: Damien Neil <dneil@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
This commit is contained in:
Tatiana Bradley 2024-05-21 14:23:02 -04:00
Родитель 1283e469ae
Коммит a2650ed283
10 изменённых файлов: 207 добавлений и 166 удалений

Просмотреть файл

@ -273,7 +273,7 @@ func createIssuesCommand(ctx context.Context) error {
func showCommand(ctx context.Context, ids []string) error {
for _, id := range ids {
r, err := cfg.Store.GetCVE4Record(ctx, id)
r, err := cfg.Store.GetRecord(ctx, id)
if err != nil {
return err
}

Просмотреть файл

@ -33,12 +33,12 @@ func updateFalsePositives(ctx context.Context, st store.Store) (err error) {
old := oldRecords[i]
var err error
if old == nil {
err = tx.CreateCVE4Record(cr)
err = tx.CreateRecord(cr)
} else if old.CommitHash != cr.CommitHash && !old.CommitTime.IsZero() && old.CommitTime.Before(cr.CommitTime) {
// If the false positive data is more recent than what is in
// the store, then update the DB. But ignore records whose
// commit time hasn't been populated.
err = tx.SetCVE4Record(cr)
err = tx.SetRecord(cr)
}
if err != nil {
return err

Просмотреть файл

@ -13,6 +13,8 @@ import (
"cloud.google.com/go/firestore"
"golang.org/x/vulndb/internal/derrors"
"golang.org/x/vulndb/internal/idstr"
"golang.org/x/vulndb/internal/report"
"google.golang.org/api/iterator"
"google.golang.org/api/option"
"google.golang.org/grpc/codes"
@ -104,19 +106,16 @@ func (fs *FireStore) SetCommitUpdateRecord(ctx context.Context, r *CommitUpdateR
return err
}
// GetCVE4Record implements store.GetCVE4Record.
func (fs *FireStore) GetCVE4Record(ctx context.Context, id string) (_ *CVE4Record, err error) {
defer derrors.Wrap(&err, "FireStore.GetCVE4Record(%q)", id)
// GetRecord implements store.GetRecord.
func (fs *FireStore) GetRecord(ctx context.Context, id string) (_ Record, err error) {
defer derrors.Wrap(&err, "FireStore.GetRecord(%q)", id)
docsnap, err := fs.cve4RecordRef(id).Get(ctx)
docsnap, err := fs.recordRef(id).Get(ctx)
if status.Code(err) == codes.NotFound {
return nil, nil
}
var cr CVE4Record
if err := docsnap.DataTo(&cr); err != nil {
return nil, err
}
return &cr, nil
return docsnapToRecord(docsnap)
}
// ListCommitUpdateRecords implements Store.ListCommitUpdateRecords.
@ -271,14 +270,15 @@ func (fs *FireStore) RunTransaction(ctx context.Context, f func(context.Context,
})
}
// cve4RecordRef returns a DocumentRef to the CVE4Record with id.
func (fs *FireStore) cve4RecordRef(id string) *firestore.DocumentRef {
return fs.nsDoc.Collection(cve4Collection).Doc(id)
}
// legacyGHSARecordRef returns a DocumentRef to the LegacyGHSARecord with id.
func (fs *FireStore) legacyGHSARecordRef(id string) *firestore.DocumentRef {
return fs.nsDoc.Collection(legacyGHSACollection).Doc(id)
func (fs *FireStore) recordRef(id string) *firestore.DocumentRef {
var collection string
switch {
case idstr.IsGHSA(id):
collection = legacyGHSACollection
case idstr.IsCVE(id):
collection = cve4Collection
}
return fs.nsDoc.Collection(collection).Doc(id)
}
// fsTransaction implements Transaction
@ -287,24 +287,15 @@ type fsTransaction struct {
t *firestore.Transaction
}
// CreateCVE4Record implements Transaction.CreateCVE4Record.
func (tx *fsTransaction) CreateCVE4Record(r *CVE4Record) (err error) {
defer derrors.Wrap(&err, "fsTransaction.CreateCVE4Record(%s)", r.ID)
// SetRecord implements Transaction.SetRecord.
func (tx *fsTransaction) SetRecord(r Record) (err error) {
defer derrors.Wrap(&err, "fsTransaction.SetRecord(%s)", r.GetID())
if err := r.Validate(); err != nil {
return err
}
return tx.t.Create(tx.s.cve4RecordRef(r.ID), r)
}
// SetCVE4Record implements Transaction.SetCVE4Record.
func (tx *fsTransaction) SetCVE4Record(r *CVE4Record) (err error) {
defer derrors.Wrap(&err, "fsTransaction.SetCVE4Record(%s)", r.ID)
if err := r.Validate(); err != nil {
return err
}
return tx.t.Set(tx.s.cve4RecordRef(r.ID), r)
return tx.t.Set(tx.s.recordRef(r.GetID()), r)
}
// GetCVE4Records implements Transaction.GetCVE4Records.
@ -335,33 +326,56 @@ func docsnapsToCVE4Records(docsnaps []*firestore.DocumentSnapshot) ([]*CVE4Recor
return crs, nil
}
// CreateLegacyGHSARecord implements Transaction.CreateLegacyGHSARecord.
func (tx *fsTransaction) CreateLegacyGHSARecord(r *LegacyGHSARecord) (err error) {
defer derrors.Wrap(&err, "fsTransaction.CreateGHSARecord(%s)", r.GHSA.ID)
return tx.t.Create(tx.s.legacyGHSARecordRef(r.GHSA.ID), r)
type Record interface {
GetID() string
GetSource() report.Source
GetUnit() string
GetIssueReference() string
GetIssueCreatedAt() time.Time
GetTriageState() TriageState
Validate() error
}
// SetLegacyGHSARecord implements Transaction.SetLegacyGHSARecord.
func (tx *fsTransaction) SetLegacyGHSARecord(r *LegacyGHSARecord) (err error) {
defer derrors.Wrap(&err, "fsTransaction.SetGHSARecord(%s)", r.GHSA.ID)
func (tx *fsTransaction) CreateRecord(r Record) (err error) {
defer derrors.Wrap(&err, "fsTransaction.CreateRecord(%s)", r.GetID())
return tx.t.Set(tx.s.legacyGHSARecordRef(r.GHSA.ID), r)
if err := r.Validate(); err != nil {
return err
}
return tx.t.Create(tx.s.recordRef(r.GetID()), r)
}
// GetLegacyGHSARecord implements Transaction.GetLegacyGHSARecord.
func (tx *fsTransaction) GetLegacyGHSARecord(id string) (_ *LegacyGHSARecord, err error) {
defer derrors.Wrap(&err, "fsTransaction.GetGHSARecord(%s)", id)
// GetRecord implements Transaction.GetRecord.
func (tx *fsTransaction) GetRecord(id string) (_ Record, err error) {
defer derrors.Wrap(&err, "fsTransaction.GetRecord(%s)", id)
docsnap, err := tx.t.Get(tx.s.legacyGHSARecordRef(id))
docsnap, err := tx.t.Get(tx.s.recordRef(id))
if status.Code(err) == codes.NotFound {
return nil, nil
}
var gr LegacyGHSARecord
if err := docsnap.DataTo(&gr); err != nil {
return docsnapToRecord(docsnap)
}
func docsnapToRecord(docsnap *firestore.DocumentSnapshot) (Record, error) {
id := docsnap.Ref.ID
var r Record
switch {
case idstr.IsGHSA(id):
r = new(LegacyGHSARecord)
case idstr.IsCVE(id):
r = new(CVE4Record)
default:
return nil, fmt.Errorf("id %s is not a CVE or GHSA id", id)
}
if err := docsnap.DataTo(r); err != nil {
return nil, err
}
return &gr, nil
return r, nil
}
// GetLegacyGHSARecords implements Transaction.GetLegacyGHSARecords.

Просмотреть файл

@ -12,6 +12,8 @@ import (
"sort"
"sync"
"time"
"golang.org/x/vulndb/internal/idstr"
)
// MemStore is an in-memory implementation of Store, for testing.
@ -83,8 +85,14 @@ func (ms *MemStore) ListCommitUpdateRecords(_ context.Context, limit int) ([]*Co
}
// GetCVE4Record implements store.GetCVE4Record.
func (ms *MemStore) GetCVE4Record(ctx context.Context, id string) (*CVE4Record, error) {
return ms.cve4Records[id], nil
func (ms *MemStore) GetRecord(_ context.Context, id string) (Record, error) {
switch {
case idstr.IsGHSA(id):
return ms.legacyGHSARecords[id], nil
case idstr.IsCVE(id):
return ms.cve4Records[id], nil
}
return nil, fmt.Errorf("%s is not a CVE or GHSA id", id)
}
// ListCVE4RecordsWithTriageState implements Store.ListCVE4RecordsWithTriageState.
@ -159,24 +167,28 @@ type memTransaction struct {
ms *MemStore
}
// CreateCVE4Record implements Transaction.CreateCVE4Record.
func (tx *memTransaction) CreateCVE4Record(r *CVE4Record) error {
if err := r.Validate(); err != nil {
return err
}
tx.ms.cve4Records[r.ID] = r
return nil
}
// SetCVE4Record implements Transaction.SetCVE4Record.
func (tx *memTransaction) SetCVE4Record(r *CVE4Record) error {
func (tx *memTransaction) SetRecord(r Record) error {
if err := r.Validate(); err != nil {
return err
}
if tx.ms.cve4Records[r.ID] == nil {
return fmt.Errorf("CVE4Record with ID %q not found", r.ID)
id := r.GetID()
switch v := r.(type) {
case *LegacyGHSARecord:
if _, ok := tx.ms.legacyGHSARecords[id]; !ok {
return fmt.Errorf("LegacyGHSARecord %s does not exist", id)
}
tx.ms.legacyGHSARecords[id] = v
case *CVE4Record:
if tx.ms.cve4Records[id] == nil {
return fmt.Errorf("CVE4Record with ID %q not found", id)
}
tx.ms.cve4Records[id] = v
default:
return fmt.Errorf("unrecognized record type %T", r)
}
tx.ms.cve4Records[r.ID] = r
return nil
}
@ -196,28 +208,44 @@ func (tx *memTransaction) GetCVE4Records(startID, endID string) ([]*CVE4Record,
return crs, nil
}
// CreateLegacyGHSARecord implements Transaction.CreateLegacyGHSARecord.
func (tx *memTransaction) CreateLegacyGHSARecord(r *LegacyGHSARecord) error {
if _, ok := tx.ms.legacyGHSARecords[r.GHSA.ID]; ok {
return fmt.Errorf("LegacyGHSARecord %s already exists", r.GHSA.ID)
// CreateRecord implements Transaction.CreateRecord.
func (tx *memTransaction) CreateRecord(r Record) error {
if err := r.Validate(); err != nil {
return err
}
tx.ms.legacyGHSARecords[r.GHSA.ID] = r
return nil
}
// SetLegacyGHSARecord implements Transaction.SetLegacyGHSARecord.
func (tx *memTransaction) SetLegacyGHSARecord(r *LegacyGHSARecord) error {
if _, ok := tx.ms.legacyGHSARecords[r.GHSA.ID]; !ok {
return fmt.Errorf("LegacyGHSARecord %s does not exist", r.GHSA.ID)
id := r.GetID()
switch v := r.(type) {
case *LegacyGHSARecord:
if _, ok := tx.ms.legacyGHSARecords[id]; ok {
return fmt.Errorf("LegacyGHSARecord %s already exists", id)
}
tx.ms.legacyGHSARecords[id] = v
return nil
case *CVE4Record:
if _, ok := tx.ms.cve4Records[id]; ok {
return fmt.Errorf("CVE4Record %s already exists", id)
}
tx.ms.cve4Records[id] = v
return nil
default:
return fmt.Errorf("unrecognized record type %T", r)
}
tx.ms.legacyGHSARecords[r.GHSA.ID] = r
return nil
}
// GetLegacyGHSARecord implements Transaction.GetLegacyGHSARecord.
func (tx *memTransaction) GetLegacyGHSARecord(id string) (*LegacyGHSARecord, error) {
if r, ok := tx.ms.legacyGHSARecords[id]; ok {
return r, nil
func (tx *memTransaction) GetRecord(id string) (Record, error) {
switch {
case idstr.IsGHSA(id):
if r, ok := tx.ms.legacyGHSARecords[id]; ok {
return r, nil
}
case idstr.IsCVE(id):
if r, ok := tx.ms.cve4Records[id]; ok {
return r, nil
}
default:
return nil, fmt.Errorf("id %s is not a CVE or GHSA id", id)
}
return nil, nil
}

Просмотреть файл

@ -69,6 +69,7 @@ func (r *CVE4Record) GetUnit() string { return r.Module }
func (r *CVE4Record) GetSource() report.Source { return r.CVE }
func (r *CVE4Record) GetIssueReference() string { return r.IssueReference }
func (r *CVE4Record) GetIssueCreatedAt() time.Time { return r.IssueCreatedAt }
func (r *CVE4Record) GetTriageState() TriageState { return r.TriageState }
// Validate returns an error if the CVE4Record is not valid.
func (r *CVE4Record) Validate() error {
@ -203,6 +204,8 @@ func (r *LegacyGHSARecord) GetUnit() string { return r.GHSA.Vulns[0
func (r *LegacyGHSARecord) GetSource() report.Source { return r.GHSA }
func (r *LegacyGHSARecord) GetIssueReference() string { return r.IssueReference }
func (r *LegacyGHSARecord) GetIssueCreatedAt() time.Time { return r.IssueCreatedAt }
func (r *LegacyGHSARecord) GetTriageState() TriageState { return r.TriageState }
func (r *LegacyGHSARecord) Validate() error { return nil }
// A ModuleScanRecord holds information about a vulnerability scan of a module.
type ModuleScanRecord struct {
@ -246,10 +249,10 @@ type Store interface {
// least recent.
ListCommitUpdateRecords(ctx context.Context, limit int) ([]*CommitUpdateRecord, error)
// GetCVE4Record returns the CVE4Record with the given id. If not found, it returns (nil, nil).
GetCVE4Record(ctx context.Context, id string) (*CVE4Record, error)
// GetRecord returns the Record with the given id. If not found, it returns (nil, nil).
GetRecord(ctx context.Context, id string) (Record, error)
// ListCVE4RecordsWithTriageState returns all CVER4ecords with the given triage state,
// ListCVE4RecordsWithTriageState returns all CVE4Records with the given triage state,
// ordered by ID.
ListCVE4RecordsWithTriageState(ctx context.Context, ts TriageState) ([]*CVE4Record, error)
@ -278,30 +281,22 @@ type Store interface {
// Transaction supports store operations that run inside a transaction.
type Transaction interface {
// CreateCVE4Record creates a new CVE4Record. It is an error if one with the same ID
// already exists.
CreateCVE4Record(*CVE4Record) error
// CreateRecord creates a new record.
// It is an error if one with the same ID already exists.
CreateRecord(Record) error
// SetCVE4Record sets the CVE record in the database. It is
// an error if no such record exists.
SetCVE4Record(r *CVE4Record) error
// SetRecord sets the record in the database.
// It is an error if no such record exists.
SetRecord(Record) error
// GetCVE4Records retrieves CVE4Records for all CVE IDs between startID and
// GetRecord returns a single record by ID.
// If not found, it returns (nil, nil).
GetRecord(id string) (Record, error)
// GetRecords retrieves records for all CVE IDs between startID and
// endID, inclusive.
GetCVE4Records(startID, endID string) ([]*CVE4Record, error)
// CreateLegacyGHSARecord creates a new LegacyGHSARecord. It is an error if one with the same ID
// already exists.
CreateLegacyGHSARecord(*LegacyGHSARecord) error
// SetLegacyGHSARecord sets the GHSA record in the database. It is
// an error if no such record exists.
SetLegacyGHSARecord(*LegacyGHSARecord) error
// GetLegacyGHSARecord returns a single LegacyGHSARecord by GHSA ID.
// If not found, it returns (nil, nil).
GetLegacyGHSARecord(id string) (*LegacyGHSARecord, error)
// GetLegacyGHSARecords returns all the GHSARecords in the database.
GetLegacyGHSARecords() ([]*LegacyGHSARecord, error)
}

Просмотреть файл

@ -150,13 +150,14 @@ func testCVEs(t *testing.T, s Store) {
set := func(r *CVE4Record) *CVE4Record {
must(s.RunTransaction(ctx, func(ctx context.Context, tx Transaction) error {
return tx.SetCVE4Record(r)
return tx.SetRecord(r)
}))(t)
return must1(s.GetCVE4Record(ctx, r.ID))(t)
return must1(s.GetRecord(ctx, r.ID))(t).(*CVE4Record)
}
// Make sure the first record is the same that we created.
got := must1(s.GetCVE4Record(ctx, id1))(t)
r := must1(s.GetRecord(ctx, id1))(t)
got := r.(*CVE4Record)
diff(t, crs[0], got)
// Change the state and the commit hash.
@ -187,22 +188,26 @@ func testDirHashes(t *testing.T, s Store) {
}
}
var (
ghsa1, ghsa2, ghsa3, ghsa4, ghsa5 = "GHSA-xxxx-yyyy-1111", "GHSA-xxxx-yyyy-2222", "GHSA-xxxx-yyyy-3333", "GHSA-xxxx-yyyy-4444", "GHSA-xxxx-yyyy-5555"
)
func testGHSAs(t *testing.T, s Store) {
ctx := context.Background()
// Create two records.
gs := []*LegacyGHSARecord{
{
GHSA: &ghsa.SecurityAdvisory{ID: "g1", Summary: "one"},
GHSA: &ghsa.SecurityAdvisory{ID: ghsa1, Summary: "one"},
TriageState: TriageStateNeedsIssue,
},
{
GHSA: &ghsa.SecurityAdvisory{ID: "g2", Summary: "two"},
GHSA: &ghsa.SecurityAdvisory{ID: ghsa2, Summary: "two"},
TriageState: TriageStateNeedsIssue,
},
}
must(s.RunTransaction(ctx, func(ctx context.Context, tx Transaction) error {
for _, g := range gs {
if err := tx.CreateLegacyGHSARecord(g); err != nil {
if err := tx.CreateRecord(g); err != nil {
return err
}
}
@ -211,7 +216,7 @@ func testGHSAs(t *testing.T, s Store) {
// Modify one of them.
gs[1].TriageState = TriageStateIssueCreated
must(s.RunTransaction(ctx, func(ctx context.Context, tx Transaction) error {
return tx.SetLegacyGHSARecord(gs[1])
return tx.SetRecord(gs[1])
}))(t)
// Retrieve and compare.
var got []*LegacyGHSARecord
@ -231,9 +236,12 @@ func testGHSAs(t *testing.T, s Store) {
// Retrieve one record by GHSA ID.
var got0 *LegacyGHSARecord
must(s.RunTransaction(ctx, func(ctx context.Context, tx Transaction) error {
var err error
got0, err = tx.GetLegacyGHSARecord(gs[0].GetID())
return err
r, err := tx.GetRecord(gs[0].GetID())
if err != nil {
return err
}
got0 = r.(*LegacyGHSARecord)
return nil
}))(t)
if got, want := got0, gs[0]; !cmp.Equal(got, want) {
t.Errorf("got %+v, want %+v", got, want)
@ -294,7 +302,7 @@ func testModuleScanRecords(t *testing.T, s Store) {
func createCVE4Records(t *testing.T, ctx context.Context, s Store, crs []*CVE4Record) {
must(s.RunTransaction(ctx, func(ctx context.Context, tx Transaction) error {
for _, cr := range crs {
if err := tx.CreateCVE4Record(cr); err != nil {
if err := tx.CreateRecord(cr); err != nil {
return err
}
}

Просмотреть файл

@ -230,13 +230,13 @@ func (u *cveUpdater) updateBatch(ctx context.Context, batch []cvelistrepo.File)
}
// Add/modify the records.
for _, r := range toAdd {
if err := tx.CreateCVE4Record(r); err != nil {
if err := tx.CreateRecord(r); err != nil {
return err
}
numAdds++
}
for _, r := range toModify {
if err := tx.SetCVE4Record(r); err != nil {
if err := tx.SetRecord(r); err != nil {
return err
}
numMods++
@ -255,12 +255,12 @@ func (u *cveUpdater) updateBatch(ctx context.Context, batch []cvelistrepo.File)
// based on this.
func checkForAliases(cve *cve4.CVE, tx store.Transaction) (store.TriageState, error) {
for _, ghsaID := range cveutils.GetAliasGHSAs(cve) {
ghsa, err := tx.GetLegacyGHSARecord(ghsaID)
ghsa, err := tx.GetRecord(ghsaID)
if err != nil {
return "", err
}
if ghsa != nil {
return getTriageStateFromAlias(ghsa.TriageState), nil
return getTriageStateFromAlias(ghsa.GetTriageState()), nil
}
}
return store.TriageStateNeedsIssue, nil
@ -541,13 +541,13 @@ func updateGHSAs(ctx context.Context, listSAs GHSAListFunc, since time.Time, st
}
for _, r := range toAdd {
if err := tx.CreateLegacyGHSARecord(r); err != nil {
if err := tx.CreateRecord(r); err != nil {
return err
}
numAdded++
}
for _, r := range toUpdate {
if err := tx.SetLegacyGHSARecord(r); err != nil {
if err := tx.SetRecord(r); err != nil {
return err
}
numModified++

Просмотреть файл

@ -12,7 +12,6 @@ import (
"errors"
"flag"
"testing"
"time"
"github.com/go-git/go-git/v5"
"github.com/go-git/go-git/v5/plumbing/object"
@ -76,13 +75,12 @@ func modify(r, m *store.CVE4Record) *store.CVE4Record {
func TestNewCVE4Record(t *testing.T) {
// Check that NewCVE4Record with a TriageState gives a valid CVE4Record.
repo, err := gitrepo.ReadTxtarRepo(testRepoPath, time.Now())
_, commit, err := gitrepo.TxtarRepoAndHead(testRepoPath)
if err != nil {
t.Fatal(err)
}
commit := headCommit(t, repo)
pathname := "2021/0xxx/CVE-2021-0001.json"
cve, bh := readCVE4(t, repo, commit, pathname)
cve, bh := readCVE4(t, commit, pathname)
cr := store.NewCVE4Record(cve, pathname, bh, commit)
cr.TriageState = store.TriageStateNeedsIssue
if err := cr.Validate(); err != nil {
@ -92,11 +90,10 @@ func TestNewCVE4Record(t *testing.T) {
func TestDoUpdate(t *testing.T) {
ctx := context.Background()
repo, err := gitrepo.ReadTxtarRepo(testRepoPath, time.Now())
repo, commit, err := gitrepo.TxtarRepoAndHead(testRepoPath)
if err != nil {
t.Fatal(err)
}
commit := headCommit(t, repo)
cf, err := pkgsite.CacheFile(t)
if err != nil {
t.Fatal(err)
@ -130,7 +127,7 @@ func TestDoUpdate(t *testing.T) {
blobHashes []string
)
for _, p := range paths {
cve, bh := readCVE4(t, repo, commit, p)
cve, bh := readCVE4(t, commit, p)
cves = append(cves, cve)
blobHashes = append(blobHashes, bh)
}
@ -553,7 +550,7 @@ func TestGroupFilesByDirectory(t *testing.T) {
}
}
func readCVE4(t *testing.T, repo *git.Repository, commit *object.Commit, path string) (*cve4.CVE, string) {
func readCVE4(t *testing.T, commit *object.Commit, path string) (*cve4.CVE, string) {
cve, blobHash, err := ReadCVEAtPath(commit, path)
if err != nil {
t.Fatal(err)
@ -565,7 +562,7 @@ func createCVE4Records(t *testing.T, s store.Store, crs []*store.CVE4Record) {
err := s.RunTransaction(context.Background(), func(ctx context.Context, tx store.Transaction) error {
for _, cr := range crs {
copy := *cr
if err := tx.CreateCVE4Record(&copy); err != nil {
if err := tx.CreateRecord(&copy); err != nil {
return err
}
}
@ -580,7 +577,7 @@ func createLegacyGHSARecords(t *testing.T, s store.Store, grs []*store.LegacyGHS
err := s.RunTransaction(context.Background(), func(ctx context.Context, tx store.Transaction) error {
for _, gr := range grs {
copy := *gr
if err := tx.CreateLegacyGHSARecord(&copy); err != nil {
if err := tx.CreateRecord(&copy); err != nil {
return err
}
}

Просмотреть файл

@ -236,15 +236,15 @@ func createCVEIssues(ctx context.Context, st store.Store, client *issues.Client,
// Update the CVE4Record in the DB with issue information.
err = st.RunTransaction(ctx, func(ctx context.Context, tx store.Transaction) error {
rs, err := tx.GetCVE4Records(cr.ID, cr.ID)
r, err := tx.GetRecord(cr.ID)
if err != nil {
return err
}
cr := rs[0]
cr := r.(*store.CVE4Record)
cr.TriageState = store.TriageStateIssueCreated
cr.IssueReference = ref
cr.IssueCreatedAt = time.Now()
return tx.SetCVE4Record(cr)
return tx.SetRecord(cr)
})
if err != nil {
return err
@ -282,12 +282,13 @@ func createGHSAIssues(ctx context.Context, st store.Store, client *issues.Client
// Update the LegacyGHSARecord in the DB to reflect that the GHSA
// already has an advisory.
if err = st.RunTransaction(ctx, func(ctx context.Context, tx store.Transaction) error {
r, err := tx.GetLegacyGHSARecord(gr.GetID())
r, err := tx.GetRecord(gr.GetID())
if err != nil {
return err
}
r.TriageState = store.TriageStateHasVuln
return tx.SetLegacyGHSARecord(r)
g := r.(*store.LegacyGHSARecord)
g.TriageState = store.TriageStateHasVuln
return tx.SetRecord(g)
}); err != nil {
return err
}
@ -300,14 +301,16 @@ func createGHSAIssues(ctx context.Context, st store.Store, client *issues.Client
}
// Update the LegacyGHSARecord in the DB with issue information.
err = st.RunTransaction(ctx, func(ctx context.Context, tx store.Transaction) error {
r, err := tx.GetLegacyGHSARecord(gr.GetID())
r, err := tx.GetRecord(gr.GetID())
if err != nil {
return err
}
r.TriageState = store.TriageStateIssueCreated
r.IssueReference = ref
r.IssueCreatedAt = time.Now()
return tx.SetLegacyGHSARecord(r)
g := r.(*store.LegacyGHSARecord)
g.TriageState = store.TriageStateIssueCreated
g.IssueReference = ref
g.IssueCreatedAt = time.Now()
return tx.SetRecord(g)
})
if err != nil {
return err
@ -355,15 +358,7 @@ func NewIssueBody(r *report.Report, rc *report.Client) (body string, err error)
return b.String(), nil
}
type storeRecord interface {
GetID() string
GetSource() report.Source
GetUnit() string
GetIssueReference() string
GetIssueCreatedAt() time.Time
}
func createIssue(ctx context.Context, r storeRecord, client *issues.Client, pc *proxy.Client, rc *report.Client) (ref string, err error) {
func createIssue(ctx context.Context, r store.Record, client *issues.Client, pc *proxy.Client, rc *report.Client) (ref string, err error) {
id := r.GetID()
defer derrors.Wrap(&err, "createIssue(%s)", id)

Просмотреть файл

@ -97,6 +97,10 @@ func TestCheckUpdate(t *testing.T) {
}
}
var (
ghsa1, ghsa2, ghsa3, ghsa4, ghsa5, ghsa6 = "GHSA-xxxx-yyyy-1111", "GHSA-xxxx-yyyy-2222", "GHSA-xxxx-yyyy-3333", "GHSA-xxxx-yyyy-4444", "GHSA-xxxx-yyyy-5555", "GHSA-xxxx-yyyy-6666"
)
func TestCreateIssues(t *testing.T) {
ctx := context.Background()
mstore := store.NewMemStore()
@ -125,20 +129,20 @@ func TestCreateIssues(t *testing.T) {
crs := []*store.CVE4Record{
{
ID: "ID1",
ID: "CVE-2000-0001",
BlobHash: "bh1",
CommitHash: "ch",
CommitTime: ctime,
Path: "path1",
CVE: &cve4.CVE{
Metadata: cve4.Metadata{
ID: "ID1",
ID: "CVE-2000-0001",
},
},
TriageState: store.TriageStateNeedsIssue,
},
{
ID: "ID2",
ID: "CVE-2000-0002",
BlobHash: "bh2",
CommitHash: "ch",
CommitTime: ctime,
@ -146,7 +150,7 @@ func TestCreateIssues(t *testing.T) {
TriageState: store.TriageStateNoActionNeeded,
},
{
ID: "ID3",
ID: "CVE-2000-0003",
BlobHash: "bh3",
CommitHash: "ch",
CommitTime: ctime,
@ -158,37 +162,37 @@ func TestCreateIssues(t *testing.T) {
grs := []*store.LegacyGHSARecord{
{
GHSA: &ghsa.SecurityAdvisory{
ID: "g1",
ID: ghsa1,
Vulns: []*ghsa.Vuln{{Package: "p1"}},
},
TriageState: store.TriageStateNeedsIssue,
},
{
GHSA: &ghsa.SecurityAdvisory{
ID: "g2",
ID: ghsa2,
Vulns: []*ghsa.Vuln{{Package: "p2"}},
},
TriageState: store.TriageStateNoActionNeeded,
},
{
GHSA: &ghsa.SecurityAdvisory{
ID: "g3",
ID: ghsa3,
Vulns: []*ghsa.Vuln{{Package: "p3"}},
},
TriageState: store.TriageStateIssueCreated,
},
{
GHSA: &ghsa.SecurityAdvisory{
ID: "g4",
ID: ghsa4,
Vulns: []*ghsa.Vuln{{Package: "p4"}},
},
TriageState: store.TriageStateAlias,
},
{
GHSA: &ghsa.SecurityAdvisory{
ID: "g5",
ID: ghsa5,
Vulns: []*ghsa.Vuln{{Package: "p1"}},
Identifiers: []ghsa.Identifier{{Type: "GHSA", Value: "g5"}},
Identifiers: []ghsa.Identifier{{Type: "GHSA", Value: ghsa5}},
},
TriageState: store.TriageStateNeedsIssue,
},
@ -197,7 +201,7 @@ func TestCreateIssues(t *testing.T) {
// Add an existing report with GHSA "g5".
rc, err := report.NewTestClient(map[string]*report.Report{
"data/reports/GO-1999-0001.yaml": {GHSAs: []string{"g5"}},
"data/reports/GO-1999-0001.yaml": {GHSAs: []string{ghsa5}},
})
if err != nil {
t.Fatal(err)
@ -447,24 +451,24 @@ func TestUpdateGHSAs(t *testing.T) {
ctx := context.Background()
sas := []*ghsa.SecurityAdvisory{
{
ID: "g1",
ID: ghsa1,
UpdatedAt: day(2021, 10, 1),
},
{
ID: "g2",
ID: ghsa2,
UpdatedAt: day(2021, 11, 1),
},
{
ID: "g3",
ID: ghsa3,
UpdatedAt: day(2021, 12, 1),
},
{
ID: "g4",
ID: ghsa4,
Identifiers: []ghsa.Identifier{{Type: "CVE", Value: "CVE-2000-1111"}},
UpdatedAt: day(2021, 12, 1),
},
{
ID: "g5",
ID: ghsa5,
Identifiers: []ghsa.Identifier{{Type: "CVE", Value: "CVE-2000-2222"}},
UpdatedAt: day(2021, 12, 1),
},
@ -532,7 +536,7 @@ func TestUpdateGHSAs(t *testing.T) {
}
want[0].GHSA = sas[0]
sas = append(sas, &ghsa.SecurityAdvisory{
ID: "g6",
ID: ghsa6,
UpdatedAt: day(2021, 12, 2),
})
listSAs = fakeListFunc(sas)