Move repo fetching to top level
This commit is contained in:
Родитель
42cb9247af
Коммит
6ac4ae1cb0
|
@ -20,50 +20,132 @@ from comma.downstream import Downstream
|
|||
from comma.upstream import Upstream
|
||||
from comma.util.spreadsheet import Spreadsheet
|
||||
from comma.util.symbols import Symbols
|
||||
from comma.util.tracking import get_linux_repo
|
||||
from comma.util.tracking import Repo
|
||||
|
||||
|
||||
LOGGER = logging.getLogger("comma.cli")
|
||||
YAML = R_YAML(typ="safe")
|
||||
|
||||
|
||||
def run(options, config, database):
|
||||
class Session:
|
||||
"""
|
||||
Handle run subcommand
|
||||
Container for session data to avoid duplicate actions
|
||||
"""
|
||||
|
||||
if options.dry_run:
|
||||
# Populate database from configuration file
|
||||
with database.get_session() as session:
|
||||
if session.query(Distros).first() is None:
|
||||
session.add_all(
|
||||
Distros(distroID=name, repoLink=url) for name, url in config.repos.items()
|
||||
)
|
||||
def __init__(self, config, database) -> None:
|
||||
self.config: Config = config
|
||||
self.database: DatabaseDriver = database
|
||||
|
||||
if session.query(MonitoringSubjects).first() is None:
|
||||
session.add_all(
|
||||
MonitoringSubjects(distroID=target.repo, revision=target.reference)
|
||||
for target in config.downstream
|
||||
)
|
||||
def _get_repo(
|
||||
self,
|
||||
since: Optional[str] = None,
|
||||
pull: bool = False,
|
||||
suffix: Optional[str] = None,
|
||||
) -> Repo:
|
||||
"""
|
||||
Clone or update a repo
|
||||
"""
|
||||
|
||||
if options.print_tracked_paths:
|
||||
for path in get_linux_repo(config.upstream_since).get_tracked_paths(
|
||||
config.upstream.sections
|
||||
):
|
||||
print(path)
|
||||
name = self.config.upstream.repo
|
||||
if suffix:
|
||||
name += f"-{suffix}"
|
||||
repo = Repo(name, self.config.repos[self.config.upstream.repo])
|
||||
|
||||
if options.upstream:
|
||||
LOGGER.info("Begin monitoring upstream")
|
||||
Upstream(config, database).process_commits()
|
||||
LOGGER.info("Finishing monitoring upstream")
|
||||
if not repo.exists:
|
||||
# No local repo, clone from source
|
||||
repo.clone(since)
|
||||
|
||||
if options.downstream:
|
||||
LOGGER.info("Begin monitoring downstream")
|
||||
Downstream(config, database).monitor()
|
||||
LOGGER.info("Finishing monitoring downstream")
|
||||
elif pull:
|
||||
repo.pull()
|
||||
|
||||
else:
|
||||
repo.fetch(since)
|
||||
|
||||
return repo
|
||||
|
||||
def run(self, options):
|
||||
"""
|
||||
Handle run subcommand
|
||||
"""
|
||||
|
||||
if options.dry_run:
|
||||
# Populate database from configuration file
|
||||
with self.database.get_session() as session:
|
||||
if session.query(Distros).first() is None:
|
||||
session.add_all(
|
||||
Distros(distroID=name, repoLink=url)
|
||||
for name, url in self.config.repos.items()
|
||||
)
|
||||
|
||||
if session.query(MonitoringSubjects).first() is None:
|
||||
session.add_all(
|
||||
MonitoringSubjects(distroID=target.repo, revision=target.reference)
|
||||
for target in self.config.downstream
|
||||
)
|
||||
|
||||
repo = self._get_repo(since=self.config.upstream_since)
|
||||
|
||||
if options.print_tracked_paths:
|
||||
for path in repo.get_tracked_paths(self.config.upstream.sections):
|
||||
print(path)
|
||||
|
||||
if options.upstream:
|
||||
LOGGER.info("Begin monitoring upstream")
|
||||
Upstream(self.config, self.database, repo).process_commits()
|
||||
LOGGER.info("Finishing monitoring upstream")
|
||||
|
||||
if options.downstream:
|
||||
LOGGER.info("Begin monitoring downstream")
|
||||
Downstream(self.config, self.database, repo).monitor()
|
||||
LOGGER.info("Finishing monitoring downstream")
|
||||
|
||||
def symbols(self, options):
|
||||
"""
|
||||
Handle symbols subcommand
|
||||
"""
|
||||
repo = self._get_repo(pull=True, suffix="sym")
|
||||
|
||||
missing = Symbols(self.config, self.database, repo).get_missing_commits(options.file)
|
||||
print("Missing symbols from:")
|
||||
for commit in missing:
|
||||
print(f" {commit}")
|
||||
|
||||
def downstream(self, options):
|
||||
"""
|
||||
Handle downstream subcommand
|
||||
"""
|
||||
|
||||
# Print current targets in database
|
||||
if options.action in {"list", None}:
|
||||
for remote, reference in self.database.iter_downstream_targets():
|
||||
print(f"{remote}\t{reference}")
|
||||
|
||||
# Add downstream target
|
||||
if options.action == "add":
|
||||
self.database.add_downstream_target(options.name, options.url, options.revision)
|
||||
|
||||
def spreadsheet(self, options):
|
||||
"""
|
||||
Handle spreadsheet subcommand
|
||||
"""
|
||||
|
||||
repo = self._get_repo(since=self.config.upstream_since)
|
||||
spreadsheet = Spreadsheet(self.config, self.database, repo)
|
||||
|
||||
if options.export_commits:
|
||||
spreadsheet.export_commits(options.in_file, options.out_file)
|
||||
if options.update_commits:
|
||||
spreadsheet.update_commits(options.in_file, options.out_file)
|
||||
|
||||
def __call__(self, options) -> None:
|
||||
"""
|
||||
Runs the specified subcommand
|
||||
"""
|
||||
|
||||
getattr(self, options.subcommand)(options)
|
||||
|
||||
|
||||
def main(args: Optional[Sequence[str]] = None): # pylint: disable=too-many-branches
|
||||
def main(args: Optional[Sequence[str]] = None):
|
||||
"""
|
||||
Main CLI entry point
|
||||
"""
|
||||
|
@ -93,33 +175,11 @@ def main(args: Optional[Sequence[str]] = None): # pylint: disable=too-many-bran
|
|||
if hasattr(options, option):
|
||||
setattr(config, option, getattr(options, option))
|
||||
|
||||
# Get database object
|
||||
database = DatabaseDriver(dry_run=options.dry_run, echo=options.verbose > 2)
|
||||
|
||||
if options.subcommand == "symbols":
|
||||
missing = Symbols(config, database).get_missing_commits(options.file)
|
||||
print("Missing symbols from:")
|
||||
for commit in missing:
|
||||
print(f" {commit}")
|
||||
|
||||
if options.subcommand == "downstream":
|
||||
# Print current targets in database
|
||||
if options.action in {"list", None}:
|
||||
for remote, reference in database.iter_downstream_targets():
|
||||
print(f"{remote}\t{reference}")
|
||||
|
||||
# Add downstream target
|
||||
if options.action == "add":
|
||||
database.add_downstream_target(options.name, options.url, options.revision)
|
||||
|
||||
if options.subcommand == "run":
|
||||
run(options, config, database)
|
||||
|
||||
if options.subcommand == "spreadsheet":
|
||||
spreadsheet = Spreadsheet(config, database)
|
||||
if args.export_commits:
|
||||
spreadsheet.export_commits(args.in_file, args.out_file)
|
||||
if args.update_commits:
|
||||
spreadsheet.update_commits(args.in_file, args.out_file)
|
||||
# Create session object and invoke subcommand
|
||||
Session(config, database)(options)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -5,7 +5,6 @@ Operations for downstream targets
|
|||
"""
|
||||
|
||||
import logging
|
||||
from functools import cached_property
|
||||
|
||||
from comma.database.model import (
|
||||
Distros,
|
||||
|
@ -14,7 +13,6 @@ from comma.database.model import (
|
|||
PatchData,
|
||||
)
|
||||
from comma.downstream.matcher import patch_matches
|
||||
from comma.util.tracking import get_linux_repo
|
||||
|
||||
|
||||
LOGGER = logging.getLogger(__name__.split(".", 1)[0])
|
||||
|
@ -25,16 +23,10 @@ class Downstream:
|
|||
Parent object for downstream operations
|
||||
"""
|
||||
|
||||
def __init__(self, config, database) -> None:
|
||||
def __init__(self, config, database, repo) -> None:
|
||||
self.config = config
|
||||
self.database = database
|
||||
|
||||
@cached_property
|
||||
def repo(self):
|
||||
"""
|
||||
Get repo when first accessed
|
||||
"""
|
||||
return get_linux_repo(since=self.config.upstream_since)
|
||||
self.repo = repo
|
||||
|
||||
def monitor(self):
|
||||
"""
|
||||
|
|
|
@ -5,10 +5,8 @@ Functions for parsing commit objects into patch objects
|
|||
"""
|
||||
|
||||
import logging
|
||||
from functools import cached_property
|
||||
|
||||
from comma.database.model import PatchData
|
||||
from comma.util.tracking import get_linux_repo
|
||||
|
||||
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
|
@ -19,16 +17,10 @@ class Upstream:
|
|||
Parent object for downstream operations
|
||||
"""
|
||||
|
||||
def __init__(self, config, database) -> None:
|
||||
def __init__(self, config, database, repo) -> None:
|
||||
self.config = config
|
||||
self.database = database
|
||||
|
||||
@cached_property
|
||||
def repo(self):
|
||||
"""
|
||||
Get repo when first accessed
|
||||
"""
|
||||
return get_linux_repo(since=self.config.upstream_since)
|
||||
self.repo = repo
|
||||
|
||||
def process_commits(self):
|
||||
"""
|
||||
|
|
|
@ -8,7 +8,6 @@ import logging
|
|||
import re
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
|
@ -20,7 +19,7 @@ from openpyxl.workbook.workbook import Workbook
|
|||
from openpyxl.worksheet.worksheet import Worksheet
|
||||
|
||||
from comma.database.model import Distros, MonitoringSubjects, PatchData
|
||||
from comma.util.tracking import get_filenames, get_linux_repo
|
||||
from comma.util.tracking import get_filenames
|
||||
|
||||
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
|
@ -72,16 +71,10 @@ class Spreadsheet:
|
|||
Parent object for symbol operations
|
||||
"""
|
||||
|
||||
def __init__(self, config, database) -> None:
|
||||
def __init__(self, config, database, repo) -> None:
|
||||
self.config = config
|
||||
self.database = database
|
||||
|
||||
@cached_property
|
||||
def repo(self):
|
||||
"""
|
||||
Get repo when first accessed
|
||||
"""
|
||||
return get_linux_repo(since=self.config.upstream_since)
|
||||
self.repo = repo
|
||||
|
||||
def get_db_commits(self) -> Dict[str, int]:
|
||||
"""Query the 'PatchData' table for all commit hashes and IDs."""
|
||||
|
|
|
@ -6,12 +6,10 @@ Functions for generating symbol maps
|
|||
|
||||
import logging
|
||||
import subprocess
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
from typing import Iterable
|
||||
|
||||
from comma.database.model import PatchData
|
||||
from comma.util.tracking import get_linux_repo
|
||||
|
||||
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
|
@ -44,16 +42,10 @@ class Symbols:
|
|||
Parent object for symbol operations
|
||||
"""
|
||||
|
||||
def __init__(self, config, database) -> None:
|
||||
def __init__(self, config, database, repo) -> None:
|
||||
self.config = config
|
||||
self.database = database
|
||||
|
||||
@cached_property
|
||||
def repo(self):
|
||||
"""
|
||||
Get repo when first accessed
|
||||
"""
|
||||
return get_linux_repo(name="linux-sym", pull=True)
|
||||
self.repo = repo
|
||||
|
||||
def get_missing_commits(self, symbol_file):
|
||||
"""Returns a sorted list of commit IDs whose symbols are missing from the given file"""
|
||||
|
|
|
@ -230,61 +230,6 @@ class Repo:
|
|||
self.obj.head.reset(index=True, working_tree=True)
|
||||
|
||||
|
||||
class Session:
|
||||
"""
|
||||
Container for session data to avoid duplicate actions
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.repos: dict = {}
|
||||
|
||||
def get_repo(
|
||||
self,
|
||||
name: str,
|
||||
url: str,
|
||||
since: Optional[str] = None,
|
||||
pull: bool = False,
|
||||
) -> Repo:
|
||||
"""
|
||||
Clone and optionally update a repo, returning the object.
|
||||
|
||||
Only clones, fetches, or pulls once per session
|
||||
"""
|
||||
|
||||
if name in self.repos:
|
||||
# Repo has been cloned, fetched, or pulled already in this session
|
||||
return self.repos[name]
|
||||
|
||||
repo = self.repos[name] = Repo(name, url)
|
||||
if not repo.exists:
|
||||
# No local repo, clone from source
|
||||
repo.clone(since)
|
||||
|
||||
elif pull:
|
||||
repo.pull()
|
||||
else:
|
||||
repo.fetch(since)
|
||||
|
||||
return repo
|
||||
|
||||
|
||||
# TODO (Issue 56): Move session creation to main program logic
|
||||
SESSION = Session()
|
||||
|
||||
|
||||
def get_linux_repo(
|
||||
name: str = "linux.git",
|
||||
url: str = "https://github.com/torvalds/linux.git",
|
||||
since: Optional[str] = None,
|
||||
pull: bool = False,
|
||||
) -> Repo:
|
||||
"""
|
||||
Shortcut for getting Linux repo
|
||||
"""
|
||||
|
||||
return SESSION.get_repo(name, url, since=since, pull=pull)
|
||||
|
||||
|
||||
def extract_paths(sections: Iterable, content: str) -> Set[str]:
|
||||
# pylint: disable=wrong-spelling-in-docstring
|
||||
"""
|
||||
|
|
Загрузка…
Ссылка в новой задаче