diff --git a/comma/util/symbols.py b/comma/util/symbols.py index 669333b..0d006b4 100755 --- a/comma/util/symbols.py +++ b/comma/util/symbols.py @@ -47,34 +47,44 @@ def map_symbols_to_patch( commits: SHA of all commits in database prev_commit: SHA of start of HyperV patch to track """ - repo.head.reference = repo.commit(prev_commit) - repo.head.reset(index=True, working_tree=True) - before_patch_apply = None - # Iterate through commits - for commit in commits: - # Get symbols before patch is applied - if before_patch_apply is None: - before_patch_apply = set(get_symbols(repo.working_tree_dir, files)) + # Preserve initial reference + initial_reference = repo.head.reference - # Checkout commit - repo.head.reference = repo.commit(commit) + try: + repo.head.reference = repo.commit(prev_commit) repo.head.reset(index=True, working_tree=True) + before_patch_apply = None - # Get symbols after patch is applied - after_patch_apply = set(get_symbols(repo.working_tree_dir, files)) + # Iterate through commits + for commit in commits: + # Get symbols before patch is applied + if before_patch_apply is None: + before_patch_apply = set(get_symbols(repo.working_tree_dir, files)) - # Compare symbols before and after patch - diff_symbols = after_patch_apply - before_patch_apply - print(f"Commit: {commit} -> {''.join(diff_symbols)}") + # Checkout commit + repo.head.reference = repo.commit(commit) + repo.head.reset(index=True, working_tree=True) - # Save symbols to database - with DatabaseDriver.get_session() as session: - patch = session.query(PatchData).filter_by(commitID=commit).one() - patch.symbols = " ".join(diff_symbols) + # Get symbols after patch is applied + after_patch_apply = set(get_symbols(repo.working_tree_dir, files)) - # Use symbols from current commit to compare to next commit - before_patch_apply = after_patch_apply + # Compare symbols before and after patch + diff_symbols = after_patch_apply - before_patch_apply + print(f"Commit: {commit} -> {''.join(diff_symbols)}") + + # Save symbols to database + with DatabaseDriver.get_session() as session: + patch = session.query(PatchData).filter_by(commitID=commit).one() + patch.symbols = " ".join(diff_symbols) + + # Use symbols from current commit to compare to next commit + before_patch_apply = after_patch_apply + + finally: + # Reset reference + repo.head.reference = initial_reference + repo.head.reset(index=True, working_tree=True) def get_hyperv_patch_symbols():