bugbug/scripts/generate_sheet.py

89 строки
2.4 KiB
Python

# -*- coding: utf-8 -*-
import argparse
import csv
import os
from datetime import datetime, timedelta
from logging import INFO, basicConfig, getLogger
import numpy as np
from bugbug import bugzilla
from bugbug.models import get_model_class
basicConfig(level=INFO)
logger = getLogger(__name__)
def generate_sheet(model_name: str, token: str, days: int, threshold: float) -> None:
model_file_name = f"{model_name}model"
assert os.path.exists(
model_file_name
), f"{model_file_name} does not exist. Train the model with trainer.py first."
model_class = get_model_class(model_name)
model = model_class.load(model_file_name)
bugzilla.set_token(token)
bug_ids = bugzilla.get_ids_between(datetime.utcnow() - timedelta(days))
bugs = bugzilla.get(bug_ids)
logger.info("Classifying %d bugs...", len(bugs))
rows = [["Bug", f"{model_name}(model)", model_name, "Title"]]
for bug in bugs.values():
p = model.classify(bug, probabilities=True)
probability = p[0]
if len(probability) > 2:
index = np.argmax(probability)
prediction = model.class_names[index]
else:
prediction = "y" if probability[1] >= threshold else "n"
rows.append(
[
f'https://bugzilla.mozilla.org/show_bug.cgi?id={bug["id"]}',
prediction,
"",
bug["summary"],
]
)
os.makedirs("sheets", exist_ok=True)
with open(
os.path.join(
"sheets",
f'{model_name}-{datetime.utcnow().strftime("%Y-%m-%d")}-labels.csv',
),
"w",
) as f:
writer = csv.writer(f)
writer.writerows(rows)
def main() -> None:
description = "Perform evaluation on bugs from specified days back on the specified model and generate a csv file "
parser = argparse.ArgumentParser(description=description)
parser.add_argument("model", help="Which model to generate a csv for.")
parser.add_argument("token", help="Bugzilla token")
parser.add_argument(
"days",
type=int,
default=7,
help="No. of days back from which bugs will be evaluated",
)
parser.add_argument(
"threshold", type=float, default=0.7, help="Confidence threshold for the model"
)
args = parser.parse_args()
generate_sheet(args.model, args.token, args.days, args.threshold)
if __name__ == "__main__":
main()