End to end duplicate model script (#1509)

Ayush Shridhar 2020-05-20 21:04:49 +10:00 committed by GitHub
Parent 49c7ee369f
Commit 8b6e7d266f
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
2 changed files: 172 additions and 56 deletions

View file

@@ -21,6 +21,7 @@ from sklearn.neighbors import NearestNeighbors
 from tqdm import tqdm
 
 from bugbug import bugzilla, feature_cleanup
+from bugbug.models.duplicate import DuplicateModel
 from bugbug.utils import download_check_etag, zstd_decompress
 
 OPT_MSG_MISSING = (
@@ -80,7 +81,9 @@ def download_and_load_similarity_model(model_name):
 class BaseSimilarity(abc.ABC):
-    def __init__(self, cleanup_urls=True, nltk_tokenizer=False):
+    def __init__(
+        self, cleanup_urls=True, nltk_tokenizer=False, confidence_threshold=0.8
+    ):
         self.cleanup_functions = [
             feature_cleanup.responses(),
             feature_cleanup.hex(),
@@ -93,6 +96,7 @@ class BaseSimilarity(abc.ABC):
             self.cleanup_functions.append(feature_cleanup.url())
 
         self.nltk_tokenizer = nltk_tokenizer
+        self.confidence_threshold = confidence_threshold
 
     def get_text(self, bug, all_comments=False):
         if all_comments:
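
Every similarity class now threads confidence_threshold (default 0.8) through to BaseSimilarity.__init__: it is the minimum probability the duplicate model must assign to a candidate pair before that pair survives end-to-end filtering. A minimal usage sketch, assuming a local bugs database has been downloaded (the 0.9 value is illustrative, not part of this commit):

    from bugbug import similarity

    # Any BaseSimilarity subclass now accepts the new keyword argument.
    model = similarity.BM25Similarity(confidence_threshold=0.9)
    assert model.confidence_threshold == 0.9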
@@ -130,7 +134,7 @@ class BaseSimilarity(abc.ABC):
             return " ".join(word for word in text)
         return text
 
-    def evaluation(self):
+    def evaluation(self, end_to_end=False):
         # A map from bug ID to its duplicate IDs
         duplicates = defaultdict(set)
         all_ids = set(
@@ -140,6 +144,9 @@ class BaseSimilarity(abc.ABC):
             and "dupeme" not in bug["keywords"]
         )
 
+        if end_to_end:
+            duplicatemodel = DuplicateModel.load("duplicatemodel")
+
         for bug in bugzilla.get_bugs():
             dupes = [entry for entry in bug["duplicates"] if entry in all_ids]
             if bug["dupe_of"] in all_ids:
@@ -169,6 +176,17 @@ class BaseSimilarity(abc.ABC):
             num_hits = 0
             queries += 1
             similar_bugs = self.get_similar_bugs(bug)[:10]
+            if end_to_end:
+                sim_bugs = {  # candidate ID -> full bug dict (avoids shadowing `bug`)
+                    b["id"]: b for b in bugzilla.get_bugs() if b["id"] in similar_bugs
+                }
+                bug_couples = [(bug, sim_bugs[bug_id]) for bug_id in similar_bugs]
+                probs = duplicatemodel.classify(bug_couples, probabilities=True)
+                similar_bugs = [  # keep only candidates the duplicate model confirms
+                    similar_bugs[idx]
+                    for idx, prob in enumerate(probs)
+                    if prob[1] > self.confidence_threshold
+                ]
 
             # Recall
             for idx, item in enumerate(duplicates[bug["id"]]):
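
With end_to_end=True, evaluation() re-scores each query bug's top-10 retrieved candidates with the supervised DuplicateModel and keeps only those scoring above confidence_threshold, so the recall and precision figures reflect the whole retrieve-then-classify pipeline rather than retrieval alone. A hedged invocation sketch (assumes the "duplicatemodel" artifact and the bugs database are both present locally):

    from bugbug.similarity import NeighborsSimilarity

    model = NeighborsSimilarity()
    model.evaluation(end_to_end=True)  # metrics over model-confirmed candidates only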
@@ -226,8 +244,14 @@ class BaseSimilarity(abc.ABC):
 class LSISimilarity(BaseSimilarity):
-    def __init__(self, cleanup_urls=True, nltk_tokenizer=False):
-        super().__init__(cleanup_urls=cleanup_urls, nltk_tokenizer=nltk_tokenizer)
+    def __init__(
+        self, cleanup_urls=True, nltk_tokenizer=False, confidence_threshold=0.8
+    ):
+        super().__init__(
+            cleanup_urls=cleanup_urls,
+            nltk_tokenizer=nltk_tokenizer,
+            confidence_threshold=confidence_threshold,
+        )
         self.corpus = []
 
         for bug in bugzilla.get_bugs():
@@ -286,8 +310,13 @@ class NeighborsSimilarity(BaseSimilarity):
         vectorizer=TfidfVectorizer(),
         cleanup_urls=True,
         nltk_tokenizer=False,
+        confidence_threshold=0.8,
     ):
-        super().__init__(cleanup_urls=cleanup_urls, nltk_tokenizer=nltk_tokenizer)
+        super().__init__(
+            cleanup_urls=cleanup_urls,
+            nltk_tokenizer=nltk_tokenizer,
+            confidence_threshold=confidence_threshold,
+        )
         self.vectorizer = vectorizer
         self.similarity_calculator = NearestNeighbors(n_neighbors=k)
 
         text = []
@@ -314,8 +343,18 @@ class NeighborsSimilarity(BaseSimilarity):
 class Word2VecSimilarityBase(BaseSimilarity):
-    def __init__(self, cut_off=0.2, cleanup_urls=True, nltk_tokenizer=False):
-        super().__init__(cleanup_urls=cleanup_urls, nltk_tokenizer=nltk_tokenizer)
+    def __init__(
+        self,
+        cut_off=0.2,
+        cleanup_urls=True,
+        nltk_tokenizer=False,
+        confidence_threshold=0.8,
+    ):
+        super().__init__(
+            cleanup_urls=cleanup_urls,
+            nltk_tokenizer=nltk_tokenizer,
+            confidence_threshold=confidence_threshold,
+        )
 
         self.corpus = []
         self.bug_ids = []
         self.cut_off = cut_off
@@ -333,8 +372,18 @@ class Word2VecSimilarityBase(BaseSimilarity):
 class Word2VecWmdSimilarity(Word2VecSimilarityBase):
-    def __init__(self, cut_off=0.2, cleanup_urls=True, nltk_tokenizer=False):
-        super().__init__(cleanup_urls=cleanup_urls, nltk_tokenizer=nltk_tokenizer)
+    def __init__(
+        self,
+        cut_off=0.2,
+        cleanup_urls=True,
+        nltk_tokenizer=False,
+        confidence_threshold=0.8,
+    ):
+        super().__init__(
+            cleanup_urls=cleanup_urls,
+            nltk_tokenizer=nltk_tokenizer,
+            confidence_threshold=confidence_threshold,
+        )
 
     # word2vec.wmdistance calculates only the euclidean distance. To get the cosine distance,
     # we're using the function with a few subtle changes. We compute the cosine distances
@@ -466,8 +515,18 @@ class Word2VecWmdSimilarity(Word2VecSimilarityBase):
 class Word2VecWmdRelaxSimilarity(Word2VecSimilarityBase):
-    def __init__(self, cut_off=0.2, cleanup_urls=True, nltk_tokenizer=False):
-        super().__init__(cleanup_urls=cleanup_urls, nltk_tokenizer=nltk_tokenizer)
+    def __init__(
+        self,
+        cut_off=0.2,
+        cleanup_urls=True,
+        nltk_tokenizer=False,
+        confidence_threshold=0.8,
+    ):
+        super().__init__(
+            cleanup_urls=cleanup_urls,
+            nltk_tokenizer=nltk_tokenizer,
+            confidence_threshold=confidence_threshold,
+        )
         self.dictionary = Dictionary(self.corpus)
         self.tfidf = TfidfModel(dictionary=self.dictionary)
@@ -564,8 +623,18 @@ class Word2VecWmdRelaxSimilarity(Word2VecSimilarityBase):
 class Word2VecSoftCosSimilarity(Word2VecSimilarityBase):
-    def __init__(self, cut_off=0.2, cleanup_urls=True, nltk_tokenizer=False):
-        super().__init__(cleanup_urls=cleanup_urls, nltk_tokenizer=nltk_tokenizer)
+    def __init__(
+        self,
+        cut_off=0.2,
+        cleanup_urls=True,
+        nltk_tokenizer=False,
+        confidence_threshold=0.8,
+    ):
+        super().__init__(
+            cleanup_urls=cleanup_urls,
+            nltk_tokenizer=nltk_tokenizer,
+            confidence_threshold=confidence_threshold,
+        )
 
         terms_idx = WordEmbeddingSimilarityIndex(self.w2vmodel.wv)
         self.dictionary = Dictionary(self.corpus)
@@ -592,8 +661,14 @@ class Word2VecSoftCosSimilarity(Word2VecSimilarityBase):
 class BM25Similarity(BaseSimilarity):
-    def __init__(self, cleanup_urls=True, nltk_tokenizer=False):
-        super().__init__(cleanup_urls=cleanup_urls, nltk_tokenizer=nltk_tokenizer)
+    def __init__(
+        self, cleanup_urls=True, nltk_tokenizer=False, confidence_threshold=0.8
+    ):
+        super().__init__(
+            cleanup_urls=cleanup_urls,
+            nltk_tokenizer=nltk_tokenizer,
+            confidence_threshold=confidence_threshold,
+        )
 
         self.corpus = []
         self.bug_ids = []
@@ -621,8 +696,14 @@ class BM25Similarity(BaseSimilarity):
 class LDASimilarity(BaseSimilarity):
-    def __init__(self, cleanup_urls=True, nltk_tokenizer=False):
-        super().__init__(cleanup_urls=cleanup_urls, nltk_tokenizer=nltk_tokenizer)
+    def __init__(
+        self, cleanup_urls=True, nltk_tokenizer=False, confidence_threshold=0.8
+    ):
+        super().__init__(
+            cleanup_urls=cleanup_urls,
+            nltk_tokenizer=nltk_tokenizer,
+            confidence_threshold=confidence_threshold,
+        )
         self.corpus = []
         self.bug_ids = []
         for bug in bugzilla.get_bugs():
@@ -666,8 +747,14 @@ class LDASimilarity(BaseSimilarity):
 class ElasticSearchSimilarity(BaseSimilarity):
-    def __init__(self, cleanup_urls=True, nltk_tokenizer=False):
-        super().__init__(cleanup_urls=cleanup_urls, nltk_tokenizer=nltk_tokenizer)
+    def __init__(
+        self, cleanup_urls=True, nltk_tokenizer=False, confidence_threshold=0.8
+    ):
+        super().__init__(
+            cleanup_urls=cleanup_urls,
+            nltk_tokenizer=nltk_tokenizer,
+            confidence_threshold=confidence_threshold,
+        )
         self.elastic_search = Elasticsearch()
         assert (
             self.elastic_search.ping()
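
For context on the prob[1] indexing used above (and again in the script below): classify(..., probabilities=True) returns one probability row per bug couple, with the positive "is a duplicate" class at index 1. A toy check, assuming bug_a and bug_b are full bug dicts:

    probs = duplicatemodel.classify([(bug_a, bug_b)], probabilities=True)
    p_not_dup, p_dup = probs[0]  # the row sums to ~1.0
    is_duplicate = p_dup > 0.8   # 0.8 mirrors the default confidence_threshold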

View file

@@ -1,53 +1,82 @@
 # -*- coding: utf-8 -*-
 
+import argparse
 import csv
-import itertools
 import json
+import sys
 from datetime import datetime, timedelta
+from itertools import combinations
 from logging import INFO, basicConfig, getLogger
 
-from bugbug import bugzilla
+from bugbug import bugzilla, similarity
 from bugbug.models.duplicate import DuplicateModel
 
-m = DuplicateModel.load("duplicatemodel")
-
 basicConfig(level=INFO)
 logger = getLogger(__name__)
 
 REPORTERS_TO_IGNORE = {"intermittent-bug-filer@mozilla.bugs", "wptsync@mozilla.bugs"}
 
-try:
-    with open("duplicate_test_bugs.json", "r") as f:
-        test_bugs = json.load(f)
-except FileNotFoundError:
-    test_bug_ids = bugzilla.get_ids_between(
-        datetime.now() - timedelta(days=21), datetime.now()
-    )
-    test_bugs = bugzilla.get(test_bug_ids)
-    test_bugs = [
-        bug for bug in test_bugs.values() if not bug["creator"] in REPORTERS_TO_IGNORE
-    ]
-    with open("duplicate_test_bugs.json", "w") as f:
-        json.dump(test_bugs, f)
+
+def parse_args(args):
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--similaritymodel", default=None)
+    return parser.parse_args(args)
+
+
+def main(args):
+    similarity_model = (
+        similarity.download_and_load_similarity_model(args.similaritymodel)
+        if args.similaritymodel
+        else None
+    )
+    duplicate_model = DuplicateModel.load("duplicatemodel")
+
+    try:
+        with open("duplicate_test_bugs.json", "r") as f:
+            test_bugs = json.load(f)
+    except FileNotFoundError:
+        test_bug_ids = bugzilla.get_ids_between(
+            datetime.now() - timedelta(days=21), datetime.now()
+        )
+        test_bugs = bugzilla.get(test_bug_ids)
+        test_bugs = [
+            bug
+            for bug in test_bugs.values()
+            if not bug["creator"] in REPORTERS_TO_IGNORE
+        ]
+        with open("duplicate_test_bugs.json", "w") as f:
+            json.dump(test_bugs, f)
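
The script is now wrapped in parse_args/main instead of running at import time, and the cached duplicate_test_bugs.json (the last 21 days of bugs, minus the ignored bot reporters) is reused across runs. A quick sanity sketch of the new parser (the model name is an assumption for illustration):

    args = parse_args(["--similaritymodel", "neighbors_tfidf"])
    assert args.similaritymodel == "neighbors_tfidf"
    assert parse_args([]).similaritymodel is None  # all-pairs fallback mode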
-bug_tuples = list(itertools.combinations(test_bugs, 2))
-
-with open("duplicate_predictions.csv", "w") as csvfile:
-    spamwriter = csv.writer(csvfile)
-
-    # Warning: Classifying all the test bugs takes a while
-    probs = m.classify(bug_tuples, probabilities=True)
-
-    spamwriter.writerow(
-        ["bug 1 ID", "bug 1 summary", "bug 2 ID", "bug 2 summary", "prediction"]
-    )
-    for bug_tuple, prob in zip(bug_tuples, probs):
-        if prob[1] > 0.8:
-            spamwriter.writerow(
-                [
-                    f'https://bugzilla.mozilla.org/show_bug.cgi?id={bug_tuple[0]["id"]}',
-                    bug_tuple[0]["summary"],
-                    f'https://bugzilla.mozilla.org/show_bug.cgi?id={bug_tuple[1]["id"]}',
-                    bug_tuple[1]["summary"],
-                    prob[1],
-                ]
-            )
+    with open("duplicate_predictions.csv", "w") as csvfile:
+        spamwriter = csv.writer(csvfile)
+        spamwriter.writerow(
+            ["bug 1 ID", "bug 1 summary", "bug 2 ID", "bug 2 summary", "prediction"]
+        )
+
+        if similarity_model:
+            bug_tuples = []
+            for test_bug in test_bugs:
+                similar_bug_ids = similarity_model.get_similar_bugs(test_bug)
+                similar_bugs = bugzilla.get(similar_bug_ids)
+                bug_tuples += [
+                    (test_bug, similar_bug) for similar_bug in similar_bugs.values()
+                ]
+        else:
+            # Materialize the pairs: classify() consumes the iterable, and it
+            # is needed again for the zip() below.
+            bug_tuples = list(combinations(test_bugs, 2))
+
+        # Warning: Classifying all the test bugs takes a while
+        probs = duplicate_model.classify(bug_tuples, probabilities=True)
+
+        # Fall back to a fixed threshold when no similarity model was given.
+        confidence_threshold = (
+            similarity_model.confidence_threshold if similarity_model else 0.8
+        )
+        for bug_tuple, prob in zip(bug_tuples, probs):
+            if prob[1] > confidence_threshold:
+                spamwriter.writerow(
+                    [
+                        f'https://bugzilla.mozilla.org/show_bug.cgi?id={bug_tuple[0]["id"]}',
+                        bug_tuple[0]["summary"],
+                        f'https://bugzilla.mozilla.org/show_bug.cgi?id={bug_tuple[1]["id"]}',
+                        bug_tuple[1]["summary"],
+                        prob[1],
+                    ]
+                )
+
+
+if __name__ == "__main__":
+    main(parse_args(sys.argv[1:]))
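
Putting it together, both modes can be driven through main(); these calls are hedged sketches that assume the trained duplicate and similarity models are available locally:

    main(parse_args(["--similaritymodel", "neighbors_tfidf"]))  # similarity-guided couples
    main(parse_args([]))  # every pair of recent test bugs (slow)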