diff --git a/scripts/generate_sheet.py b/scripts/generate_sheet.py index 88d4c3ed..bf04c064 100644 --- a/scripts/generate_sheet.py +++ b/scripts/generate_sheet.py @@ -5,6 +5,8 @@ import csv import os from datetime import datetime, timedelta +import numpy as np + from bugbug import bugzilla from bugbug.models import get_model_class @@ -31,10 +33,17 @@ def generate_sheet(model_name, token, days, threshold): for bug in bugs.values(): p = model.classify(bug, probabilities=True) + probability = p[0] + if len(probability) > 2: + index = np.argmax(probability) + prediction = model.class_names[index] + else: + prediction = "y" if probability[1] >= threshold else "n" + rows.append( [ f'https://bugzilla.mozilla.org/show_bug.cgi?id={bug["id"]}', - "y" if p[0][1] >= threshold else "n", + prediction, "", bug["summary"], ]