Move symbol operations to class

This commit is contained in:
Avram Lubkin 2023-06-15 09:08:22 -04:00 коммит произвёл Avram Lubkin
Родитель 393a70ed47
Коммит 91f113f15f
2 изменённых файлов: 92 добавлений и 80 удалений

Просмотреть файл

@ -14,7 +14,7 @@ from comma.downstream import Downstream
from comma.upstream import Upstream from comma.upstream import Upstream
from comma.util import config from comma.util import config
from comma.util.spreadsheet import export_commits, import_commits, update_commits from comma.util.spreadsheet import export_commits, import_commits, update_commits
from comma.util.symbols import get_missing_commits from comma.util.symbols import Symbols
from comma.util.tracking import get_linux_repo from comma.util.tracking import get_linux_repo
@ -68,7 +68,7 @@ def main(args: Optional[Sequence[str]] = None):
# TODO(Issue 25: resolve configuration # TODO(Issue 25: resolve configuration
if options.subcommand == "symbols": if options.subcommand == "symbols":
missing = get_missing_commits(options.file) missing = Symbols(config, DatabaseDriver()).get_missing_commits(options.file)
print("Missing symbols from:") print("Missing symbols from:")
for commit in missing: for commit in missing:
print(f" {commit}") print(f" {commit}")

Просмотреть файл

@ -6,11 +6,11 @@ Functions for generating symbol maps
import logging import logging
import subprocess import subprocess
from functools import cached_property
from pathlib import Path from pathlib import Path
from typing import Iterable
from comma.database.driver import DatabaseDriver
from comma.database.model import PatchData from comma.database.model import PatchData
from comma.util import config
from comma.util.tracking import get_linux_repo from comma.util.tracking import get_linux_repo
@ -21,7 +21,7 @@ def get_symbols(repo_dir, files):
""" """
Returns a set of symbols for given files Returns a set of symbols for given files
files: iterable of files files: iterable of files
returns: set of symbols generated through ctags returns set of symbols generated through ctags
""" """
command = "ctags -R -x ckinds=f {}".format( command = "ctags -R -x ckinds=f {}".format(
" ".join(files) + " | awk '{ if ($2 == \"function\") print $1 }'" " ".join(files) + " | awk '{ if ($2 == \"function\") print $1 }'"
@ -39,99 +39,111 @@ def get_symbols(repo_dir, files):
return set(process.stdout.splitlines()) return set(process.stdout.splitlines())
def map_symbols_to_patch( class Symbols:
repo, commits, files, prev_commit="097c1bd5673edaf2a162724636858b71f658fdd2"
):
""" """
This function generates and stores symbols generated by each patch Parent object for symbol operations
repo: git repo object
files: hyperV files
commits: SHA of all commits in database
prev_commit: SHA of start of HyperV patch to track
""" """
LOGGER.info("Mapping symbols to commits") def __init__(self, config, database) -> None:
self.config = config
self.database = database
# Preserve initial reference @cached_property
initial_reference = repo.head.reference def repo(self):
"""
Get repo when first accessed
"""
return get_linux_repo(name="linux-sym", pull=True)
try: def get_missing_commits(self, symbol_file):
repo.checkout(prev_commit) """Returns a sorted list of commit IDs whose symbols are missing from the given file"""
before_patch_apply = None
# Iterate through commits LOGGER.info("Starting Symbol Checker")
for commit in commits: self.get_patch_symbols()
# Get symbols before patch is applied LOGGER.info("Detecting missing symbols")
if before_patch_apply is None: return self.symbol_checker(symbol_file)
before_patch_apply = get_symbols(repo.working_tree_dir, files)
# Checkout commit def get_patch_symbols(self):
repo.checkout(commit) """
This function clones upstream and gets upstream commits
"""
# Get symbols after patch is applied with self.database.get_session() as session:
after_patch_apply = get_symbols(repo.working_tree_dir, files) # SQLAlchemy returns tuples which need to be unwrapped
self.map_symbols_to_patch(
[
commit[0]
for commit in session.query(PatchData.commitID)
.order_by(PatchData.commitTime)
.all()
],
self.repo.get_tracked_paths(self.config.sections),
)
# Compare symbols before and after patch def map_symbols_to_patch(
diff_symbols = after_patch_apply - before_patch_apply self, commits: Iterable[str], paths, prev_commit="097c1bd5673edaf2a162724636858b71f658fdd2"
if diff_symbols: ):
print(f"Commit: {commit} -> {' '.join(diff_symbols)}") """
This function generates and stores symbols generated by each patch
repo: git repo object
files: hyperV files
commits: SHA of all commits in database
prev_commit: SHA of start of HyperV patch to track
"""
# Save symbols to database LOGGER.info("Mapping symbols to commits")
with DatabaseDriver.get_session() as session:
patch = session.query(PatchData).filter_by(commitID=commit).one()
patch.symbols = " ".join(diff_symbols)
# Use symbols from current commit to compare to next commit # Preserve initial reference
before_patch_apply = after_patch_apply initial_reference = self.repo.head.reference
finally: try:
# Reset reference self.repo.checkout(prev_commit)
repo.checkout(initial_reference) before_patch_apply = None
# Iterate through commits
for commit in commits:
# Get symbols before patch is applied
if before_patch_apply is None:
before_patch_apply = get_symbols(self.repo.working_tree_dir, paths)
def get_hyperv_patch_symbols(): # Checkout commit
""" self.repo.checkout(commit)
This function clones upstream and gets upstream commits, hyperV files
"""
repo = get_linux_repo(name="linux-sym", pull=True) # Get symbols after patch is applied
after_patch_apply = get_symbols(self.repo.working_tree_dir, paths)
with DatabaseDriver.get_session() as session: # Compare symbols before and after patch
# SQLAlchemy returns tuples which need to be unwrapped diff_symbols = after_patch_apply - before_patch_apply
map_symbols_to_patch( if diff_symbols:
repo, print(f"Commit: {commit} -> {' '.join(diff_symbols)}")
[
commit[0]
for commit in session.query(PatchData.commitID).order_by(PatchData.commitTime).all()
],
repo.get_tracked_paths(config.sections),
)
# Save symbols to database
with self.database.get_session() as session:
patch = session.query(PatchData).filter_by(commitID=commit).one()
patch.symbols = " ".join(diff_symbols)
def symbol_checker(filepath: Path): # Use symbols from current commit to compare to next commit
""" before_patch_apply = after_patch_apply
This function returns missing symbols by comparing database patch symbols with given symbols
symbol_file: file containing symbols to run against database
return missing_symbols_patch: list of missing symbols from given list
"""
with open(filepath, "r", encoding="utf-8") as symbol_file:
symbols_in_file = {line.strip() for line in symbol_file}
with DatabaseDriver.get_session() as session: finally:
return sorted( # Reset reference
commitID self.repo.checkout(initial_reference)
for commitID, symbols in session.query(PatchData.commitID, PatchData.symbols)
.filter(PatchData.symbols != " ")
.order_by(PatchData.commitTime)
.all()
if len(set(symbols.split(" ")) - symbols_in_file) > 0
)
def symbol_checker(self, file_path: Path):
"""
This function returns missing symbols by comparing database patch symbols with given symbols
file_path: file containing symbols to run against database
returns sorted list of commits whose symbols are missing from file
"""
with open(file_path, "r", encoding="utf-8") as symbol_file:
symbols_in_file = {line.strip() for line in symbol_file}
def get_missing_commits(symbol_file): with self.database.get_session() as session:
"""Returns a sorted list of commit IDs whose symbols are missing from the given file""" return sorted(
commitID
LOGGER.info("Starting Symbol Checker") for commitID, symbols in session.query(PatchData.commitID, PatchData.symbols)
get_hyperv_patch_symbols() .filter(PatchData.symbols != " ")
LOGGER.info("Detecting missing symbols") .order_by(PatchData.commitTime)
return symbol_checker(symbol_file) .all()
if len(set(symbols.split(" ")) - symbols_in_file) > 0
)