Move downstream operations to class

This commit is contained in:
Avram Lubkin 2023-06-14 12:54:14 -04:00 коммит произвёл Avram Lubkin
Родитель d3adb46aad
Коммит fca2841576
4 изменённых файлов: 244 добавлений и 247 удалений

Просмотреть файл

@ -10,8 +10,7 @@ from typing import Optional, Sequence
from comma.cli.parser import parse_args from comma.cli.parser import parse_args
from comma.database.driver import DatabaseDriver from comma.database.driver import DatabaseDriver
from comma.database.model import Distros, MonitoringSubjects from comma.database.model import Distros, MonitoringSubjects
from comma.downstream import add_downstream_target, list_downstream from comma.downstream import Downstream
from comma.downstream.monitor import monitor_downstream
from comma.upstream import process_commits from comma.upstream import process_commits
from comma.util import config from comma.util import config
from comma.util.spreadsheet import export_commits, import_commits, update_commits from comma.util.spreadsheet import export_commits, import_commits, update_commits
@ -45,7 +44,7 @@ def run(options):
if options.downstream: if options.downstream:
LOGGER.info("Begin monitoring downstream") LOGGER.info("Begin monitoring downstream")
monitor_downstream() Downstream(config, DatabaseDriver()).monitor_downstream()
LOGGER.info("Finishing monitoring downstream") LOGGER.info("Finishing monitoring downstream")
@ -75,13 +74,14 @@ def main(args: Optional[Sequence[str]] = None):
print(f" {commit}") print(f" {commit}")
if options.subcommand == "downstream": if options.subcommand == "downstream":
downstream = Downstream(config, DatabaseDriver())
# Print current targets in database # Print current targets in database
if options.action in {"list", None}: if options.action in {"list", None}:
list_downstream() downstream.list_targets()
# Add downstream target # Add downstream target
if options.action == "add": if options.action == "add":
add_downstream_target(options) downstream.add_target(options.name, options.url, options.revision)
if options.subcommand == "run": if options.subcommand == "run":
run(options) run(options)

Просмотреть файл

@ -13,7 +13,7 @@ from contextlib import contextmanager
import sqlalchemy import sqlalchemy
from comma.database.model import Base from comma.database.model import Base, MonitoringSubjects
from comma.util import config from comma.util import config
@ -89,3 +89,35 @@ class DatabaseDriver:
raise raise
finally: finally:
session.close() session.close()
def update_revisions_for_distro(self, distro_id, revs):
    """
    Synchronize the revisions stored for a distro with the given ones

    Stored revisions of distro_id that are not listed in revs are deleted,
    then every revision in revs that is not stored yet is inserted.

    distro_id: distro whose MonitoringSubjects rows are updated
    revs: revisions that should remain, or be added, under this distro_id
    """

    # Removal and insertion each run in their own session, one after the other
    with self.get_session() as session:
        stale = (
            session.query(MonitoringSubjects)
            .filter_by(distroID=distro_id)
            .filter(~MonitoringSubjects.revision.in_(revs))
        )
        for stale_subject in stale:
            LOGGER.info(
                "For distro %s, deleting revision: %s", distro_id, stale_subject.revision
            )

        # Bulk delete; the session is closed immediately afterwards, so
        # in-session object synchronization is not requested.
        stale.delete(synchronize_session=False)

    with self.get_session() as session:
        # Very few revisions are expected here, so one lookup per revision is fine
        for rev in revs:
            existing = (
                session.query(MonitoringSubjects)
                .filter_by(distroID=distro_id, revision=rev)
                .first()
            )
            if existing is not None:
                continue

            LOGGER.info("For distro %s, adding revision: %s", distro_id, rev)
            session.add(MonitoringSubjects(distroID=distro_id, revision=rev))

Просмотреть файл

@ -6,18 +6,44 @@ Operations for downstream targets
import logging import logging
import sys import sys
from functools import cached_property
from comma.database.driver import DatabaseDriver from comma.database.model import (
from comma.database.model import Distros, MonitoringSubjects Distros,
MonitoringSubjects,
MonitoringSubjectsMissingPatches,
PatchData,
)
from comma.downstream.matcher import patch_matches
from comma.upstream import process_commits
from comma.util.tracking import get_linux_repo
LOGGER = logging.getLogger(__name__.split(".", 1)[0]) LOGGER = logging.getLogger(__name__.split(".", 1)[0])
def list_downstream(): class Downstream:
"""
Parent object for downstream operations
"""
def __init__(self, config, database) -> None:
self.config = config
self.database = database
@cached_property
def repo(self):
    """
    Linux repo object, created by get_linux_repo() on first access

    cached_property stores the result on the instance, so later accesses
    reuse the same object instead of calling get_linux_repo() again.
    """
    linux_repo = get_linux_repo()
    return linux_repo
def list_targets(
self,
):
"""List downstream targets""" """List downstream targets"""
with DatabaseDriver.get_session() as session: with self.database.get_session() as session:
for distro, revision in ( for distro, revision in (
session.query(Distros.distroID, MonitoringSubjects.revision) session.query(Distros.distroID, MonitoringSubjects.revision)
.outerjoin(MonitoringSubjects, Distros.distroID == MonitoringSubjects.distroID) .outerjoin(MonitoringSubjects, Distros.distroID == MonitoringSubjects.distroID)
@ -25,24 +51,175 @@ def list_downstream():
): ):
print(f"{distro}\t{revision}") print(f"{distro}\t{revision}")
def add_target(self, name, url, revision):
def add_downstream_target(options):
""" """
Add a downstream target Add a downstream target
""" """
with DatabaseDriver.get_session() as session: with self.database.get_session() as session:
# Add repo # Add repo
if options.url: if url:
session.add(Distros(distroID=options.name, repoLink=options.url)) session.add(Distros(distroID=name, repoLink=url))
LOGGER.info("Successfully added new repo %s at %s", options.name, options.url) LOGGER.info("Successfully added new repo %s at %s", name, url)
# If URL wasn't given, make sure repo is in database # If URL wasn't given, make sure repo is in database
elif (options.name,) not in session.query(Distros.distroID).all(): elif (name,) not in session.query(Distros.distroID).all():
sys.exit(f"Repository '{options.name}' given without URL not found in database") sys.exit(f"Repository '{name}' given without URL not found in database")
# Add target # Add target
session.add(MonitoringSubjects(distroID=options.name, revision=options.revision)) session.add(MonitoringSubjects(distroID=name, revision=revision))
LOGGER.info("Successfully added new revision '%s' for distro '%s'", revision, name)
def monitor_downstream(self):
"""
Cycle through downstream remotes and search for missing commits
"""
repo = self.repo
# Add repos as a remote if not already added
with self.database.get_session() as session:
for distro_id, url in session.query(Distros.distroID, Distros.repoLink).all():
# Skip Debian for now
if distro_id not in self.repo.remotes and not distro_id.startswith("Debian"):
LOGGER.debug("Adding remote %s from %s", distro_id, url)
repo.create_remote(distro_id, url=url)
# Update stored revisions for repos as appropriate
LOGGER.info("Updating tracked revisions for each repo.")
with self.database.get_session() as session:
for (distro_id,) in session.query(Distros.distroID).all():
self.update_tracked_revisions(distro_id)
with self.database.get_session() as session:
subjects = session.query(MonitoringSubjects).all()
total = len(subjects)
for num, subject in enumerate(subjects, 1):
if subject.distroID.startswith("Debian"):
# TODO (Issue 51): Don't skip Debian
LOGGER.debug("skipping Debian")
continue
# Use distro name for local refs to prevent duplicates
if subject.revision.startswith(f"{subject.distroID}/"):
local_ref = subject.revision
remote_ref = subject.revision.split("/", 1)[-1]
else:
local_ref = f"{subject.distroID}/{subject.revision}"
remote_ref = subject.revision
LOGGER.info( LOGGER.info(
"Successfully added new revision '%s' for distro '%s'", options.revision, options.name "(%d of %d) Fetching remote ref %s from remote %s",
num,
total,
remote_ref,
subject.distroID,
) )
repo.fetch_remote_ref(subject.distroID, local_ref, remote_ref)
LOGGER.info(
"(%d of %d) Monitoring Script starting for distro: %s, revision: %s",
num,
total,
subject.distroID,
remote_ref,
)
self.monitor_subject(subject, local_ref)
def monitor_subject(self, monitoring_subject, reference: str):
    """
    Refresh the missing-patch records stored for one monitoring subject

    monitoring_subject: the MonitoringSubject whose records are updated
    reference: Git reference to monitor
    """

    repo = self.repo
    missing_cherries = repo.get_missing_cherries(reference, repo.get_tracked_paths())
    LOGGER.debug("Found %d missing patches through cherry-pick.", len(missing_cherries))

    # The cherry-pick result is only a first pass; run the extra checks on it
    missing_patch_ids = self.get_missing_patch_ids(missing_cherries, reference)
    LOGGER.info("Identified %d missing patches", len(missing_patch_ids))

    subject_id = monitoring_subject.monitoringSubjectID

    # NOTE: Removal and insertion use separate sessions so that objects are
    # cleanly expired and each change set is committed on its own.
    with self.database.get_session() as session:
        # Stored records whose patchID is NOT IN the latest set of missing
        # patchIDs are no longer missing.
        resolved = (
            session.query(MonitoringSubjectsMissingPatches)
            .filter_by(monitoringSubjectID=subject_id)
            .filter(~MonitoringSubjectsMissingPatches.patchID.in_(missing_patch_ids))
        )
        LOGGER.info("Deleting %d patches that are now present.", resolved.count())

        # Bulk delete; the session is closed immediately afterwards
        resolved.delete(synchronize_session=False)

    with self.database.get_session() as session:
        tracked = session.query(MonitoringSubjectsMissingPatches).filter_by(
            monitoringSubjectID=subject_id
        )

        added = 0
        # Roughly a hundred patches are expected, so a lookup per patch is acceptable
        for patch_id in missing_patch_ids:
            if tracked.filter_by(patchID=patch_id).first() is not None:
                continue

            added += 1
            session.add(
                MonitoringSubjectsMissingPatches(
                    monitoringSubjectID=subject_id, patchID=patch_id
                )
            )

        LOGGER.info("Adding %d patches that are now missing.", added)
def get_missing_patch_ids(self, missing_cherries, reference):
    """
    Attempt to determine which patches are missing from a list of missing cherries

    missing_cherries: commit IDs reported as missing by the cherry-pick check
    reference: Git reference of the downstream revision to compare against

    Returns a list of patchIDs for the upstream patches that have no match
    among the downstream commits. The list is empty when nothing was reported
    missing or when none of the reported commits are known to the database.
    """

    # Nothing reported missing: skip the database lookup and the slow
    # downstream commit processing entirely.
    if not missing_cherries:
        return []

    with self.database.get_session() as session:
        patches = (
            session.query(PatchData)
            .filter(PatchData.commitID.in_(missing_cherries))
            .order_by(PatchData.commitTime)
            .all()
        )

    # min() raises ValueError on an empty sequence, so stop here when none of
    # the missing commits have a PatchData row.
    if not patches:
        return []

    # We only want to check downstream patches as old as the oldest upstream missing patch
    earliest_commit_date = min(patch.commitTime for patch in patches).isoformat()
    LOGGER.debug("Processing commits since %s", earliest_commit_date)

    # Get the downstream commits for this revision (these are distinct from upstream because
    # they have been cherry-picked). This is slow but necessary!
    downstream_patches = process_commits(revision=reference, since=earliest_commit_date)

    # Double check the missing cherries using our fuzzy algorithm.
    LOGGER.info("Starting confidence matching for %d upstream patches...", len(patches))
    return [
        patch.patchID for patch in patches if not patch_matches(downstream_patches, patch)
    ]
def update_tracked_revisions(self, distro_id):
    """
    Refresh the tracked revisions stored for a distro from its remote tags

    This method contains distro-specific logic: only distros whose ID starts
    with "Ubuntu" are handled. For those, the last two matching tags are
    passed to the database as the revisions to track.

    distro_id: ID of the distro; also passed to get_remote_tags()
    """

    if not distro_id.startswith("Ubuntu"):
        return

    # Tags are used in the order get_remote_tags() returns them, which sorts
    # alphabetically and not by actual date. While technically wrong, this is
    # preferred: ls-remote could sort by date, but that would require all the
    # objects to be local.
    # NOTE(review): if no tag matches, an empty selection is passed on and the
    # stored revisions for this distro are removed -- confirm this is intended.
    excluded_labels = ("edge", "cvm", "fde")
    tag_names = tuple(
        tag
        for tag in self.repo.get_remote_tags(distro_id)
        if "azure" in tag and not any(label in tag for label in excluded_labels)
    )
    latest_two = tag_names[-2:]
    self.database.update_revisions_for_distro(distro_id, latest_two)

Просмотреть файл

@ -1,212 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
Functions for monitoring downstream repos for missing commits
"""
import logging
from comma.database.driver import DatabaseDriver
from comma.database.model import (
Distros,
MonitoringSubjects,
MonitoringSubjectsMissingPatches,
PatchData,
)
from comma.downstream.matcher import patch_matches
from comma.upstream import process_commits
from comma.util.tracking import get_linux_repo
LOGGER = logging.getLogger(__name__)
def update_revisions_for_distro(distro_id, revs):
    """
    Synchronize the revisions stored for a distro with the given ones

    Stored revisions of distro_id that are not listed in revs are deleted,
    then every revision in revs that is not stored yet is inserted.

    distro_id: distro whose MonitoringSubjects rows are updated
    revs: revisions that should remain, or be added, under this distro_id
    """

    # Removal and insertion each run in their own session, one after the other
    with DatabaseDriver.get_session() as session:
        stale = (
            session.query(MonitoringSubjects)
            .filter_by(distroID=distro_id)
            .filter(~MonitoringSubjects.revision.in_(revs))
        )
        for stale_subject in stale:
            LOGGER.info(
                "For distro %s, deleting revision: %s", distro_id, stale_subject.revision
            )

        # Bulk delete; the session is closed immediately afterwards, so
        # in-session object synchronization is not requested.
        stale.delete(synchronize_session=False)

    with DatabaseDriver.get_session() as session:
        # Very few revisions are expected here, so one lookup per revision is fine
        for rev in revs:
            existing = (
                session.query(MonitoringSubjects)
                .filter_by(distroID=distro_id, revision=rev)
                .first()
            )
            if existing is not None:
                continue

            LOGGER.info("For distro %s, adding revision: %s", distro_id, rev)
            session.add(MonitoringSubjects(distroID=distro_id, revision=rev))
def update_tracked_revisions(distro_id, repo):
    """
    Refresh the tracked revisions stored for a distro from its remote tags

    This function contains distro-specific logic: only distros whose ID
    starts with "Ubuntu" are handled. For those, the last two matching tags
    are stored as the revisions to track.

    distro_id: ID of the distro; also passed to repo.get_remote_tags()
    repo: the git repo object to check revisions in
    """

    if not distro_id.startswith("Ubuntu"):
        return

    # Tags are used in the order get_remote_tags() returns them, which sorts
    # alphabetically and not by actual date. While technically wrong, this is
    # preferred: ls-remote could sort by date, but that would require all the
    # objects to be local.
    excluded_labels = ("edge", "cvm", "fde")
    tag_names = tuple(
        tag
        for tag in repo.get_remote_tags(distro_id)
        if "azure" in tag and not any(label in tag for label in excluded_labels)
    )
    latest_two = tag_names[-2:]
    update_revisions_for_distro(distro_id, latest_two)
def monitor_subject(monitoring_subject, repo, reference=None):
    """
    Refresh the missing-patch records stored for one monitoring subject

    monitoring_subject: the MonitoringSubject whose records are updated
    repo: the git repo object pointing to the relevant upstream Linux repo
    reference: Git reference to monitor; the subject's own revision is used
        when this is None
    """

    if reference is None:
        reference = monitoring_subject.revision

    missing_cherries = repo.get_missing_cherries(reference, repo.get_tracked_paths())
    LOGGER.debug("Found %d missing patches through cherry-pick.", len(missing_cherries))

    # The cherry-pick result is only a first pass; run the extra checks on it
    missing_patch_ids = get_missing_patch_ids(missing_cherries, reference)
    LOGGER.info("Identified %d missing patches", len(missing_patch_ids))

    subject_id = monitoring_subject.monitoringSubjectID

    # NOTE: Removal and insertion use separate sessions so that objects are
    # cleanly expired and each change set is committed on its own.
    with DatabaseDriver.get_session() as session:
        # Stored records whose patchID is NOT IN the latest set of missing
        # patchIDs are no longer missing.
        resolved = (
            session.query(MonitoringSubjectsMissingPatches)
            .filter_by(monitoringSubjectID=subject_id)
            .filter(~MonitoringSubjectsMissingPatches.patchID.in_(missing_patch_ids))
        )
        LOGGER.info("Deleting %d patches that are now present.", resolved.count())

        # Bulk delete; the session is closed immediately afterwards
        resolved.delete(synchronize_session=False)

    with DatabaseDriver.get_session() as session:
        tracked = session.query(MonitoringSubjectsMissingPatches).filter_by(
            monitoringSubjectID=subject_id
        )

        added = 0
        # Roughly a hundred patches are expected, so a lookup per patch is acceptable
        for patch_id in missing_patch_ids:
            if tracked.filter_by(patchID=patch_id).first() is not None:
                continue

            added += 1
            session.add(
                MonitoringSubjectsMissingPatches(
                    monitoringSubjectID=subject_id, patchID=patch_id
                )
            )

        LOGGER.info("Adding %d patches that are now missing.", added)
def get_missing_patch_ids(missing_cherries, reference):
    """
    Attempt to determine which patches are missing from a list of missing cherries

    missing_cherries: commit IDs reported as missing by the cherry-pick check
    reference: Git reference of the downstream revision to compare against

    Returns a list of patchIDs for the upstream patches that have no match
    among the downstream commits. The list is empty when nothing was reported
    missing or when none of the reported commits are known to the database.
    """

    # Nothing reported missing: skip the database lookup and the slow
    # downstream commit processing entirely.
    if not missing_cherries:
        return []

    with DatabaseDriver.get_session() as session:
        patches = (
            session.query(PatchData)
            .filter(PatchData.commitID.in_(missing_cherries))
            .order_by(PatchData.commitTime)
            .all()
        )

    # min() raises ValueError on an empty sequence, so stop here when none of
    # the missing commits have a PatchData row.
    if not patches:
        return []

    # We only want to check downstream patches as old as the oldest upstream missing patch
    earliest_commit_date = min(patch.commitTime for patch in patches).isoformat()
    LOGGER.debug("Processing commits since %s", earliest_commit_date)

    # Get the downstream commits for this revision (these are distinct from upstream because
    # they have been cherry-picked). This is slow but necessary!
    downstream_patches = process_commits(revision=reference, since=earliest_commit_date)

    # Double check the missing cherries using our fuzzy algorithm.
    LOGGER.info("Starting confidence matching for %d upstream patches...", len(patches))
    return [
        patch.patchID for patch in patches if not patch_matches(downstream_patches, patch)
    ]
def monitor_downstream():
    """
    Walk all downstream remotes and record which commits each one is missing

    Steps: register every distro in the database as a git remote, refresh the
    tracked revisions per distro, then fetch and monitor each tracked revision.
    Distros whose ID starts with "Debian" are skipped for now.
    """

    repo = get_linux_repo()

    # Register each distro as a remote unless it already exists
    with DatabaseDriver.get_session() as session:
        for distro_id, url in session.query(Distros.distroID, Distros.repoLink).all():
            # Skip Debian for now
            if distro_id in repo.remotes or distro_id.startswith("Debian"):
                continue

            LOGGER.debug("Adding remote %s from %s", distro_id, url)
            repo.create_remote(distro_id, url=url)

    # Bring the stored revisions of every distro up to date
    LOGGER.info("Updating tracked revisions for each repo.")
    with DatabaseDriver.get_session() as session:
        for (distro_id,) in session.query(Distros.distroID).all():
            update_tracked_revisions(distro_id, repo)

    with DatabaseDriver.get_session() as session:
        subjects = session.query(MonitoringSubjects).all()
        total = len(subjects)

        for num, subject in enumerate(subjects, 1):
            if subject.distroID.startswith("Debian"):
                # TODO (Issue 51): Don't skip Debian
                LOGGER.debug("skipping Debian")
                continue

            # Local refs carry the distro name as a prefix to prevent duplicates;
            # the remote ref is the same revision without that prefix.
            if subject.revision.startswith(f"{subject.distroID}/"):
                local_ref = subject.revision
                remote_ref = subject.revision.split("/", 1)[-1]
            else:
                local_ref = f"{subject.distroID}/{subject.revision}"
                remote_ref = subject.revision

            LOGGER.info(
                "(%d of %d) Fetching remote ref %s from remote %s",
                num,
                total,
                remote_ref,
                subject.distroID,
            )
            repo.fetch_remote_ref(subject.distroID, local_ref, remote_ref)

            LOGGER.info(
                "(%d of %d) Monitoring Script starting for distro: %s, revision: %s",
                num,
                total,
                subject.distroID,
                remote_ref,
            )
            monitor_subject(subject, repo, local_ref)