Move symbol operations to class

This commit is contained in:
Avram Lubkin 2023-06-15 09:08:22 -04:00 коммит произвёл Avram Lubkin
Родитель 393a70ed47
Коммит 91f113f15f
2 изменённых файлов: 92 добавлений и 80 удалений

Просмотреть файл

@ -14,7 +14,7 @@ from comma.downstream import Downstream
from comma.upstream import Upstream
from comma.util import config
from comma.util.spreadsheet import export_commits, import_commits, update_commits
from comma.util.symbols import get_missing_commits
from comma.util.symbols import Symbols
from comma.util.tracking import get_linux_repo
@ -68,7 +68,7 @@ def main(args: Optional[Sequence[str]] = None):
# TODO(Issue 25: resolve configuration
if options.subcommand == "symbols":
missing = get_missing_commits(options.file)
missing = Symbols(config, DatabaseDriver()).get_missing_commits(options.file)
print("Missing symbols from:")
for commit in missing:
print(f" {commit}")

Просмотреть файл

@ -6,11 +6,11 @@ Functions for generating symbol maps
import logging
import subprocess
from functools import cached_property
from pathlib import Path
from typing import Iterable
from comma.database.driver import DatabaseDriver
from comma.database.model import PatchData
from comma.util import config
from comma.util.tracking import get_linux_repo
@ -21,7 +21,7 @@ def get_symbols(repo_dir, files):
"""
Returns a set of symbols for given files
files: iterable of files
returns: set of symbols generated through ctags
returns set of symbols generated through ctags
"""
command = "ctags -R -x ckinds=f {}".format(
" ".join(files) + " | awk '{ if ($2 == \"function\") print $1 }'"
@ -39,99 +39,111 @@ def get_symbols(repo_dir, files):
return set(process.stdout.splitlines())
def map_symbols_to_patch(
repo, commits, files, prev_commit="097c1bd5673edaf2a162724636858b71f658fdd2"
):
class Symbols:
"""
This function generates and stores symbols generated by each patch
repo: git repo object
files: hyperV files
commits: SHA of all commits in database
prev_commit: SHA of start of HyperV patch to track
Parent object for symbol operations
"""
LOGGER.info("Mapping symbols to commits")
def __init__(self, config, database) -> None:
self.config = config
self.database = database
# Preserve initial reference
initial_reference = repo.head.reference
@cached_property
def repo(self):
"""
Get repo when first accessed
"""
return get_linux_repo(name="linux-sym", pull=True)
try:
repo.checkout(prev_commit)
before_patch_apply = None
def get_missing_commits(self, symbol_file):
"""Returns a sorted list of commit IDs whose symbols are missing from the given file"""
# Iterate through commits
for commit in commits:
# Get symbols before patch is applied
if before_patch_apply is None:
before_patch_apply = get_symbols(repo.working_tree_dir, files)
LOGGER.info("Starting Symbol Checker")
self.get_patch_symbols()
LOGGER.info("Detecting missing symbols")
return self.symbol_checker(symbol_file)
# Checkout commit
repo.checkout(commit)
def get_patch_symbols(self):
"""
This function clones upstream and gets upstream commits
"""
# Get symbols after patch is applied
after_patch_apply = get_symbols(repo.working_tree_dir, files)
with self.database.get_session() as session:
# SQLAlchemy returns tuples which need to be unwrapped
self.map_symbols_to_patch(
[
commit[0]
for commit in session.query(PatchData.commitID)
.order_by(PatchData.commitTime)
.all()
],
self.repo.get_tracked_paths(self.config.sections),
)
# Compare symbols before and after patch
diff_symbols = after_patch_apply - before_patch_apply
if diff_symbols:
print(f"Commit: {commit} -> {' '.join(diff_symbols)}")
def map_symbols_to_patch(
self, commits: Iterable[str], paths, prev_commit="097c1bd5673edaf2a162724636858b71f658fdd2"
):
"""
This function generates and stores symbols generated by each patch
repo: git repo object
files: hyperV files
commits: SHA of all commits in database
prev_commit: SHA of start of HyperV patch to track
"""
# Save symbols to database
with DatabaseDriver.get_session() as session:
patch = session.query(PatchData).filter_by(commitID=commit).one()
patch.symbols = " ".join(diff_symbols)
LOGGER.info("Mapping symbols to commits")
# Use symbols from current commit to compare to next commit
before_patch_apply = after_patch_apply
# Preserve initial reference
initial_reference = self.repo.head.reference
finally:
# Reset reference
repo.checkout(initial_reference)
try:
self.repo.checkout(prev_commit)
before_patch_apply = None
# Iterate through commits
for commit in commits:
# Get symbols before patch is applied
if before_patch_apply is None:
before_patch_apply = get_symbols(self.repo.working_tree_dir, paths)
def get_hyperv_patch_symbols():
"""
This function clones upstream and gets upstream commits, hyperV files
"""
# Checkout commit
self.repo.checkout(commit)
repo = get_linux_repo(name="linux-sym", pull=True)
# Get symbols after patch is applied
after_patch_apply = get_symbols(self.repo.working_tree_dir, paths)
with DatabaseDriver.get_session() as session:
# SQLAlchemy returns tuples which need to be unwrapped
map_symbols_to_patch(
repo,
[
commit[0]
for commit in session.query(PatchData.commitID).order_by(PatchData.commitTime).all()
],
repo.get_tracked_paths(config.sections),
)
# Compare symbols before and after patch
diff_symbols = after_patch_apply - before_patch_apply
if diff_symbols:
print(f"Commit: {commit} -> {' '.join(diff_symbols)}")
# Save symbols to database
with self.database.get_session() as session:
patch = session.query(PatchData).filter_by(commitID=commit).one()
patch.symbols = " ".join(diff_symbols)
def symbol_checker(filepath: Path):
"""
This function returns missing symbols by comparing database patch symbols with given symbols
symbol_file: file containing symbols to run against database
return missing_symbols_patch: list of missing symbols from given list
"""
with open(filepath, "r", encoding="utf-8") as symbol_file:
symbols_in_file = {line.strip() for line in symbol_file}
# Use symbols from current commit to compare to next commit
before_patch_apply = after_patch_apply
with DatabaseDriver.get_session() as session:
return sorted(
commitID
for commitID, symbols in session.query(PatchData.commitID, PatchData.symbols)
.filter(PatchData.symbols != " ")
.order_by(PatchData.commitTime)
.all()
if len(set(symbols.split(" ")) - symbols_in_file) > 0
)
finally:
# Reset reference
self.repo.checkout(initial_reference)
def symbol_checker(self, file_path: Path):
"""
This function returns missing symbols by comparing database patch symbols with given symbols
file_path: file containing symbols to run against database
returns sorted list of commits whose symbols are missing from file
"""
with open(file_path, "r", encoding="utf-8") as symbol_file:
symbols_in_file = {line.strip() for line in symbol_file}
def get_missing_commits(symbol_file):
"""Returns a sorted list of commit IDs whose symbols are missing from the given file"""
LOGGER.info("Starting Symbol Checker")
get_hyperv_patch_symbols()
LOGGER.info("Detecting missing symbols")
return symbol_checker(symbol_file)
with self.database.get_session() as session:
return sorted(
commitID
for commitID, symbols in session.query(PatchData.commitID, PatchData.symbols)
.filter(PatchData.symbols != " ")
.order_by(PatchData.commitTime)
.all()
if len(set(symbols.split(" ")) - symbols_in_file) > 0
)