From 91f113f15f65f2164f4cac2a83ee0e49e4e0480d Mon Sep 17 00:00:00 2001 From: Avram Lubkin Date: Thu, 15 Jun 2023 09:08:22 -0400 Subject: [PATCH] Move symbol operations to class --- comma/cli/__init__.py | 4 +- comma/util/symbols.py | 168 ++++++++++++++++++++++-------------------- 2 files changed, 92 insertions(+), 80 deletions(-) diff --git a/comma/cli/__init__.py b/comma/cli/__init__.py index cacb82f..3771872 100644 --- a/comma/cli/__init__.py +++ b/comma/cli/__init__.py @@ -14,7 +14,7 @@ from comma.downstream import Downstream from comma.upstream import Upstream from comma.util import config from comma.util.spreadsheet import export_commits, import_commits, update_commits -from comma.util.symbols import get_missing_commits +from comma.util.symbols import Symbols from comma.util.tracking import get_linux_repo @@ -68,7 +68,7 @@ def main(args: Optional[Sequence[str]] = None): # TODO(Issue 25: resolve configuration if options.subcommand == "symbols": - missing = get_missing_commits(options.file) + missing = Symbols(config, DatabaseDriver()).get_missing_commits(options.file) print("Missing symbols from:") for commit in missing: print(f" {commit}") diff --git a/comma/util/symbols.py b/comma/util/symbols.py index 85e12cb..1e90d01 100755 --- a/comma/util/symbols.py +++ b/comma/util/symbols.py @@ -6,11 +6,11 @@ Functions for generating symbol maps import logging import subprocess +from functools import cached_property from pathlib import Path +from typing import Iterable -from comma.database.driver import DatabaseDriver from comma.database.model import PatchData -from comma.util import config from comma.util.tracking import get_linux_repo @@ -21,7 +21,7 @@ def get_symbols(repo_dir, files): """ Returns a set of symbols for given files files: iterable of files - returns: set of symbols generated through ctags + returns set of symbols generated through ctags """ command = "ctags -R -x −−c−kinds=f {}".format( " ".join(files) + " | awk '{ if ($2 == \"function\") print $1 }'" 
@@ -39,99 +39,111 @@ def get_symbols(repo_dir, files): return set(process.stdout.splitlines()) -def map_symbols_to_patch( - repo, commits, files, prev_commit="097c1bd5673edaf2a162724636858b71f658fdd2" -): +class Symbols: """ - This function generates and stores symbols generated by each patch - repo: git repo object - files: hyperV files - commits: SHA of all commits in database - prev_commit: SHA of start of HyperV patch to track + Parent object for symbol operations """ - LOGGER.info("Mapping symbols to commits") + def __init__(self, config, database) -> None: + self.config = config + self.database = database - # Preserve initial reference - initial_reference = repo.head.reference + @cached_property + def repo(self): + """ + Get repo when first accessed + """ + return get_linux_repo(name="linux-sym", pull=True) - try: - repo.checkout(prev_commit) - before_patch_apply = None + def get_missing_commits(self, symbol_file): + """Returns a sorted list of commit IDs whose symbols are missing from the given file""" - # Iterate through commits - for commit in commits: - # Get symbols before patch is applied - if before_patch_apply is None: - before_patch_apply = get_symbols(repo.working_tree_dir, files) + LOGGER.info("Starting Symbol Checker") + self.get_patch_symbols() + LOGGER.info("Detecting missing symbols") + return self.symbol_checker(symbol_file) - # Checkout commit - repo.checkout(commit) + def get_patch_symbols(self): + """ + This function clones upstream and gets upstream commits + """ - # Get symbols after patch is applied - after_patch_apply = get_symbols(repo.working_tree_dir, files) + with self.database.get_session() as session: + # SQLAlchemy returns tuples which need to be unwrapped + self.map_symbols_to_patch( + [ + commit[0] + for commit in session.query(PatchData.commitID) + .order_by(PatchData.commitTime) + .all() + ], + self.repo.get_tracked_paths(self.config.sections), + ) - # Compare symbols before and after patch - diff_symbols = after_patch_apply 
- before_patch_apply - if diff_symbols: - print(f"Commit: {commit} -> {' '.join(diff_symbols)}") + def map_symbols_to_patch( + self, commits: Iterable[str], paths, prev_commit="097c1bd5673edaf2a162724636858b71f658fdd2" + ): + """ + This function generates and stores symbols generated by each patch + self.repo: git repo object + paths: hyperV files + commits: SHA of all commits in database + prev_commit: SHA of start of HyperV patch to track + """ - # Save symbols to database - with DatabaseDriver.get_session() as session: - patch = session.query(PatchData).filter_by(commitID=commit).one() - patch.symbols = " ".join(diff_symbols) + LOGGER.info("Mapping symbols to commits") - # Use symbols from current commit to compare to next commit - before_patch_apply = after_patch_apply + # Preserve initial reference + initial_reference = self.repo.head.reference - finally: - # Reset reference - repo.checkout(initial_reference) + try: + self.repo.checkout(prev_commit) + before_patch_apply = None + # Iterate through commits + for commit in commits: + # Get symbols before patch is applied + if before_patch_apply is None: + before_patch_apply = get_symbols(self.repo.working_tree_dir, paths) -def get_hyperv_patch_symbols(): - """ - This function clones upstream and gets upstream commits, hyperV files - """ + # Checkout commit + self.repo.checkout(commit) - repo = get_linux_repo(name="linux-sym", pull=True) + # Get symbols after patch is applied + after_patch_apply = get_symbols(self.repo.working_tree_dir, paths) - with DatabaseDriver.get_session() as session: - # SQLAlchemy returns tuples which need to be unwrapped - map_symbols_to_patch( - repo, - [ - commit[0] - for commit in session.query(PatchData.commitID).order_by(PatchData.commitTime).all() - ], - repo.get_tracked_paths(config.sections), - ) + # Compare symbols before and after patch + diff_symbols = after_patch_apply - before_patch_apply + if diff_symbols: + print(f"Commit: {commit} -> {' '.join(diff_symbols)}") + # Save symbols 
to database + with self.database.get_session() as session: + patch = session.query(PatchData).filter_by(commitID=commit).one() + patch.symbols = " ".join(diff_symbols) -def symbol_checker(filepath: Path): - """ - This function returns missing symbols by comparing database patch symbols with given symbols - symbol_file: file containing symbols to run against database - return missing_symbols_patch: list of missing symbols from given list - """ - with open(filepath, "r", encoding="utf-8") as symbol_file: - symbols_in_file = {line.strip() for line in symbol_file} + # Use symbols from current commit to compare to next commit + before_patch_apply = after_patch_apply - with DatabaseDriver.get_session() as session: - return sorted( - commitID - for commitID, symbols in session.query(PatchData.commitID, PatchData.symbols) - .filter(PatchData.symbols != " ") - .order_by(PatchData.commitTime) - .all() - if len(set(symbols.split(" ")) - symbols_in_file) > 0 - ) + finally: + # Reset reference + self.repo.checkout(initial_reference) + def symbol_checker(self, file_path: Path): + """ + This function returns missing symbols by comparing database patch symbols with given symbols + file_path: file containing symbols to run against database + returns sorted list of commits whose symbols are missing from file + """ + with open(file_path, "r", encoding="utf-8") as symbol_file: + symbols_in_file = {line.strip() for line in symbol_file} -def get_missing_commits(symbol_file): - """Returns a sorted list of commit IDs whose symbols are missing from the given file""" - - LOGGER.info("Starting Symbol Checker") - get_hyperv_patch_symbols() - LOGGER.info("Detecting missing symbols") - return symbol_checker(symbol_file) + with self.database.get_session() as session: + return sorted( + commitID + for commitID, symbols in session.query(PatchData.commitID, PatchData.symbols) + .filter(PatchData.symbols != " ") + .order_by(PatchData.commitTime) + .all() + if len(set(symbols.split(" ")) - 
symbols_in_file) > 0 + )