import argparse
import csv
import json
import logging
import sys

from dotenv import load_dotenv

from bugbug.generative_model_tool import create_llm_from_args
from bugbug.tools.comment_resolver import (
    CodeGeneratorEvaluatorTool,
    FixCommentDB,
    LocalQdrantVectorDB,
)


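# Look up the human-written fix for a (revision, patch) pair. The dataset is a
# JSONL file: one JSON object per line with "revision_id", "initial_patch_id",
# and "fix_patch_diff" fields.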
def find_fix_in_dataset(revision_id, initial_patch_id, dataset_file):
    with open(dataset_file, "r") as f:
        for line in f:
            data = json.loads(line)
            if data["revision_id"] == int(revision_id) and data[
                "initial_patch_id"
            ] == int(initial_patch_id):
                return data["fix_patch_diff"]
    return None


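# Score a generated fix against the reference fix with bag-of-words token
# overlap. For example, reference "a b c" vs. generated "a b d" share the
# tokens {a, b}, giving precision 2/3, recall 2/3, and F1 2/3.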
def calculate_metrics(reference_fix, generated_fix):
    reference_tokens = reference_fix.split()
    generated_tokens = generated_fix.split()

    common_tokens = set(reference_tokens) & set(generated_tokens)
    precision = len(common_tokens) / len(generated_tokens) if generated_tokens else 0
    recall = len(common_tokens) / len(reference_tokens) if reference_tokens else 0
    f1 = (2 * precision * recall) / (precision + recall) if (precision + recall) else 0

    return {"precision": precision, "recall": recall, "f1": f1}


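# Compute metrics when a reference fix exists; otherwise log the miss and
# return None so the caller can skip the row.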
def compare_fixes(revision_id, initial_patch_id, generated_fix, reference_fix):
    if reference_fix:
        metrics = calculate_metrics(reference_fix, generated_fix)
        return metrics
    else:
        logging.info(
            f"No matching fix found in the dataset for Revision {revision_id} and Patch {initial_patch_id}."
        )
        return None


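# Read the generation script's CSV, enrich each row with the reference fix,
# token-overlap metrics, and LLM feedback, and write the result to a new CSV.
# Expected input columns: "Revision ID", "Patch ID", "Generated Fix",
# "Comment", and "Relevant Diff".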
def conduct_evaluation(input_csv, output_csv, llm_tool):
    with open(input_csv, "r") as infile, open(
        output_csv, mode="w", newline=""
    ) as outfile:
        reader = csv.DictReader(infile)

        fieldnames = reader.fieldnames + [
            "Reference Fix",
            "Precision",
            "Recall",
            "F1",
            "Qualitative Feedback",
        ]
        writer = csv.DictWriter(outfile, fieldnames=fieldnames)
        writer.writeheader()

        for row in reader:
            revision_id = row["Revision ID"]
            initial_patch_id = row["Patch ID"]
            generated_fix = row["Generated Fix"]
            comment = row["Comment"]
            relevant_diff = row["Relevant Diff"]

            reference_fix = find_fix_in_dataset(
                revision_id=revision_id,
                initial_patch_id=initial_patch_id,
                dataset_file="data/fixed_comments.json",
            )

            metrics = compare_fixes(
                revision_id=revision_id,
                initial_patch_id=initial_patch_id,
                generated_fix=generated_fix,
                reference_fix=reference_fix,
            )

            qualitative_feedback = llm_tool.generate_fix(
                comment, relevant_diff, generated_fix
            )

            if metrics is not None:
                writer.writerow(
                    {
                        **row,
                        "Reference Fix": reference_fix,
                        "Precision": metrics["precision"],
                        "Recall": metrics["recall"],
                        "F1": metrics["f1"],
                        "Qualitative Feedback": qualitative_feedback,
                    }
                )


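# Wire everything together: load credentials from .env, build the LLM and the
# Qdrant-backed comment database, then run the evaluation.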
def run(args) -> None:
    load_dotenv()
    logging.basicConfig(level=logging.INFO)

    db = FixCommentDB(LocalQdrantVectorDB(collection_name="fix_comments"))
    llm = create_llm_from_args(args)
    llm_tool = CodeGeneratorEvaluatorTool(llm=llm, db=db)

    input_csv = args.input_csv
    output_csv = args.output_csv

    conduct_evaluation(input_csv, output_csv, llm_tool)


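# Command-line interface: the LLM backend plus input/output CSV paths.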
def parse_args(args):
    parser = argparse.ArgumentParser()
    parser.add_argument("--llm", help="LLM", choices=["openai"], default="openai")
    parser.add_argument(
        "--input-csv",
        type=str,
        default="code_generations.csv",
        help="Input CSV file from the generation script.",
    )
    parser.add_argument(
        "--output-csv",
        type=str,
        default="evaluated_code_generations.csv",
        help="Output CSV file for results.",
    )
    return parser.parse_args(args)


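# Example invocation (the script filename below is illustrative, and an OpenAI
# API key is assumed to be available via .env or the environment):
#   python comment_resolver_evaluator.py --llm openai \
#       --input-csv code_generations.csv --output-csv evaluated_code_generations.csv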
if __name__ == "__main__":
    args = parse_args(sys.argv[1:])
    run(args)