Move symbol operations to class
This commit is contained in:
Родитель
393a70ed47
Коммит
91f113f15f
|
@ -14,7 +14,7 @@ from comma.downstream import Downstream
|
|||
from comma.upstream import Upstream
|
||||
from comma.util import config
|
||||
from comma.util.spreadsheet import export_commits, import_commits, update_commits
|
||||
from comma.util.symbols import get_missing_commits
|
||||
from comma.util.symbols import Symbols
|
||||
from comma.util.tracking import get_linux_repo
|
||||
|
||||
|
||||
|
@ -68,7 +68,7 @@ def main(args: Optional[Sequence[str]] = None):
|
|||
# TODO(Issue 25: resolve configuration
|
||||
|
||||
if options.subcommand == "symbols":
|
||||
missing = get_missing_commits(options.file)
|
||||
missing = Symbols(config, DatabaseDriver()).get_missing_commits(options.file)
|
||||
print("Missing symbols from:")
|
||||
for commit in missing:
|
||||
print(f" {commit}")
|
||||
|
|
|
@ -6,11 +6,11 @@ Functions for generating symbol maps
|
|||
|
||||
import logging
|
||||
import subprocess
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
from typing import Iterable
|
||||
|
||||
from comma.database.driver import DatabaseDriver
|
||||
from comma.database.model import PatchData
|
||||
from comma.util import config
|
||||
from comma.util.tracking import get_linux_repo
|
||||
|
||||
|
||||
|
@ -21,7 +21,7 @@ def get_symbols(repo_dir, files):
|
|||
"""
|
||||
Returns a set of symbols for given files
|
||||
files: iterable of files
|
||||
returns: set of symbols generated through ctags
|
||||
returns set of symbols generated through ctags
|
||||
"""
|
||||
command = "ctags -R -x −−c−kinds=f {}".format(
|
||||
" ".join(files) + " | awk '{ if ($2 == \"function\") print $1 }'"
|
||||
|
@ -39,99 +39,111 @@ def get_symbols(repo_dir, files):
|
|||
return set(process.stdout.splitlines())
|
||||
|
||||
|
||||
def map_symbols_to_patch(
|
||||
repo, commits, files, prev_commit="097c1bd5673edaf2a162724636858b71f658fdd2"
|
||||
):
|
||||
class Symbols:
|
||||
"""
|
||||
This function generates and stores symbols generated by each patch
|
||||
repo: git repo object
|
||||
files: hyperV files
|
||||
commits: SHA of all commits in database
|
||||
prev_commit: SHA of start of HyperV patch to track
|
||||
Parent object for symbol operations
|
||||
"""
|
||||
|
||||
LOGGER.info("Mapping symbols to commits")
|
||||
def __init__(self, config, database) -> None:
|
||||
self.config = config
|
||||
self.database = database
|
||||
|
||||
# Preserve initial reference
|
||||
initial_reference = repo.head.reference
|
||||
@cached_property
|
||||
def repo(self):
|
||||
"""
|
||||
Get repo when first accessed
|
||||
"""
|
||||
return get_linux_repo(name="linux-sym", pull=True)
|
||||
|
||||
try:
|
||||
repo.checkout(prev_commit)
|
||||
before_patch_apply = None
|
||||
def get_missing_commits(self, symbol_file):
|
||||
"""Returns a sorted list of commit IDs whose symbols are missing from the given file"""
|
||||
|
||||
# Iterate through commits
|
||||
for commit in commits:
|
||||
# Get symbols before patch is applied
|
||||
if before_patch_apply is None:
|
||||
before_patch_apply = get_symbols(repo.working_tree_dir, files)
|
||||
LOGGER.info("Starting Symbol Checker")
|
||||
self.get_patch_symbols()
|
||||
LOGGER.info("Detecting missing symbols")
|
||||
return self.symbol_checker(symbol_file)
|
||||
|
||||
# Checkout commit
|
||||
repo.checkout(commit)
|
||||
def get_patch_symbols(self):
|
||||
"""
|
||||
This function clones upstream and gets upstream commits
|
||||
"""
|
||||
|
||||
# Get symbols after patch is applied
|
||||
after_patch_apply = get_symbols(repo.working_tree_dir, files)
|
||||
with self.database.get_session() as session:
|
||||
# SQLAlchemy returns tuples which need to be unwrapped
|
||||
self.map_symbols_to_patch(
|
||||
[
|
||||
commit[0]
|
||||
for commit in session.query(PatchData.commitID)
|
||||
.order_by(PatchData.commitTime)
|
||||
.all()
|
||||
],
|
||||
self.repo.get_tracked_paths(self.config.sections),
|
||||
)
|
||||
|
||||
# Compare symbols before and after patch
|
||||
diff_symbols = after_patch_apply - before_patch_apply
|
||||
if diff_symbols:
|
||||
print(f"Commit: {commit} -> {' '.join(diff_symbols)}")
|
||||
def map_symbols_to_patch(
|
||||
self, commits: Iterable[str], paths, prev_commit="097c1bd5673edaf2a162724636858b71f658fdd2"
|
||||
):
|
||||
"""
|
||||
This function generates and stores symbols generated by each patch
|
||||
repo: git repo object
|
||||
files: hyperV files
|
||||
commits: SHA of all commits in database
|
||||
prev_commit: SHA of start of HyperV patch to track
|
||||
"""
|
||||
|
||||
# Save symbols to database
|
||||
with DatabaseDriver.get_session() as session:
|
||||
patch = session.query(PatchData).filter_by(commitID=commit).one()
|
||||
patch.symbols = " ".join(diff_symbols)
|
||||
LOGGER.info("Mapping symbols to commits")
|
||||
|
||||
# Use symbols from current commit to compare to next commit
|
||||
before_patch_apply = after_patch_apply
|
||||
# Preserve initial reference
|
||||
initial_reference = self.repo.head.reference
|
||||
|
||||
finally:
|
||||
# Reset reference
|
||||
repo.checkout(initial_reference)
|
||||
try:
|
||||
self.repo.checkout(prev_commit)
|
||||
before_patch_apply = None
|
||||
|
||||
# Iterate through commits
|
||||
for commit in commits:
|
||||
# Get symbols before patch is applied
|
||||
if before_patch_apply is None:
|
||||
before_patch_apply = get_symbols(self.repo.working_tree_dir, paths)
|
||||
|
||||
def get_hyperv_patch_symbols():
|
||||
"""
|
||||
This function clones upstream and gets upstream commits, hyperV files
|
||||
"""
|
||||
# Checkout commit
|
||||
self.repo.checkout(commit)
|
||||
|
||||
repo = get_linux_repo(name="linux-sym", pull=True)
|
||||
# Get symbols after patch is applied
|
||||
after_patch_apply = get_symbols(self.repo.working_tree_dir, paths)
|
||||
|
||||
with DatabaseDriver.get_session() as session:
|
||||
# SQLAlchemy returns tuples which need to be unwrapped
|
||||
map_symbols_to_patch(
|
||||
repo,
|
||||
[
|
||||
commit[0]
|
||||
for commit in session.query(PatchData.commitID).order_by(PatchData.commitTime).all()
|
||||
],
|
||||
repo.get_tracked_paths(config.sections),
|
||||
)
|
||||
# Compare symbols before and after patch
|
||||
diff_symbols = after_patch_apply - before_patch_apply
|
||||
if diff_symbols:
|
||||
print(f"Commit: {commit} -> {' '.join(diff_symbols)}")
|
||||
|
||||
# Save symbols to database
|
||||
with self.database.get_session() as session:
|
||||
patch = session.query(PatchData).filter_by(commitID=commit).one()
|
||||
patch.symbols = " ".join(diff_symbols)
|
||||
|
||||
def symbol_checker(filepath: Path):
|
||||
"""
|
||||
This function returns missing symbols by comparing database patch symbols with given symbols
|
||||
symbol_file: file containing symbols to run against database
|
||||
return missing_symbols_patch: list of missing symbols from given list
|
||||
"""
|
||||
with open(filepath, "r", encoding="utf-8") as symbol_file:
|
||||
symbols_in_file = {line.strip() for line in symbol_file}
|
||||
# Use symbols from current commit to compare to next commit
|
||||
before_patch_apply = after_patch_apply
|
||||
|
||||
with DatabaseDriver.get_session() as session:
|
||||
return sorted(
|
||||
commitID
|
||||
for commitID, symbols in session.query(PatchData.commitID, PatchData.symbols)
|
||||
.filter(PatchData.symbols != " ")
|
||||
.order_by(PatchData.commitTime)
|
||||
.all()
|
||||
if len(set(symbols.split(" ")) - symbols_in_file) > 0
|
||||
)
|
||||
finally:
|
||||
# Reset reference
|
||||
self.repo.checkout(initial_reference)
|
||||
|
||||
def symbol_checker(self, file_path: Path):
|
||||
"""
|
||||
This function returns missing symbols by comparing database patch symbols with given symbols
|
||||
file_path: file containing symbols to run against database
|
||||
returns sorted list of commits whose symbols are missing from file
|
||||
"""
|
||||
with open(file_path, "r", encoding="utf-8") as symbol_file:
|
||||
symbols_in_file = {line.strip() for line in symbol_file}
|
||||
|
||||
def get_missing_commits(symbol_file):
|
||||
"""Returns a sorted list of commit IDs whose symbols are missing from the given file"""
|
||||
|
||||
LOGGER.info("Starting Symbol Checker")
|
||||
get_hyperv_patch_symbols()
|
||||
LOGGER.info("Detecting missing symbols")
|
||||
return symbol_checker(symbol_file)
|
||||
with self.database.get_session() as session:
|
||||
return sorted(
|
||||
commitID
|
||||
for commitID, symbols in session.query(PatchData.commitID, PatchData.symbols)
|
||||
.filter(PatchData.symbols != " ")
|
||||
.order_by(PatchData.commitTime)
|
||||
.all()
|
||||
if len(set(symbols.split(" ")) - symbols_in_file) > 0
|
||||
)
|
||||
|
|
Загрузка…
Ссылка в новой задаче