# Mirror of https://github.com/mozilla/bugbug.git
# -*- coding: utf-8 -*-
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this file,
# You can obtain one at http://mozilla.org/MPL/2.0/.

import json
import logging
import os
from urllib.request import urlretrieve

import requests
from redis import Redis

from bugbug import bugzilla, get_bugbug_version
from bugbug.models import load_model
from bugbug.utils import zstd_decompress
# Send INFO-level (and above) records to stderr via the root logger.
logging.basicConfig(level=logging.INFO)
LOGGER = logging.getLogger()

# Names of the models this service can serve classifications for.
MODELS_NAMES = ["defectenhancementtask", "component", "regression", "stepstoreproduce"]
# Local directory where downloaded model files are stored.
MODELS_DIR = os.path.join(os.path.dirname(__file__), "models")
# Taskcluster index URL template; the single {} placeholder is the model name.
BASE_URL = "https://index.taskcluster.net/v1/task/project.relman.bugbug.train_{}.latest/artifacts/public"
DEFAULT_EXPIRATION_TTL = 7 * 24 * 3600  # A week

# In-process cache of loaded models, keyed by model name.
MODEL_CACHE = {}

# When set (e.g. "1" in development), a missing model file is skipped
# instead of raising FileNotFoundError.
ALLOW_MISSING_MODELS = bool(int(os.environ.get("BUGBUG_ALLOW_MISSING_MODELS", "0")))
def result_key(model_name, bug_id):
    """Redis key under which the classification result for a bug is stored."""
    return "result_{}_{}".format(model_name, bug_id)
def change_time_key(model_name, bug_id):
    """Redis key under which a bug's last change time is stored."""
    return "bugbug:change_time_{}_{}".format(model_name, bug_id)
def get_model(model_name):
    """Return the model for *model_name*, loading and caching it on first use.

    Returns None (without caching, so a later call retries) when the model
    file is missing and ALLOW_MISSING_MODELS is set; otherwise a missing
    file propagates FileNotFoundError.
    """
    if model_name not in MODEL_CACHE:
        # Use the module logger for consistency with the rest of the file
        # (this originally used a bare print).
        LOGGER.info("Recreating the %r model in cache", model_name)
        try:
            model = load_model(model_name, MODELS_DIR)
        except FileNotFoundError:
            if ALLOW_MISSING_MODELS:
                LOGGER.info(
                    "Missing %r model, skipping because ALLOW_MISSING_MODELS is set",
                    model_name,
                )
                return None
            raise

        MODEL_CACHE[model_name] = model
        return model

    return MODEL_CACHE[model_name]
def preload_models():
    """Eagerly load every known model into the in-process cache."""
    for model_name in MODELS_NAMES:
        get_model(model_name)
def retrieve_model(name):
    """Download the model *name* into MODELS_DIR if it changed upstream.

    Compares the remote ETag against the locally recorded one to skip the
    download when the local copy is already up to date.  Returns the path
    of the decompressed model file.
    """
    os.makedirs(MODELS_DIR, exist_ok=True)

    file_name = f"{name}model"
    file_path = os.path.join(MODELS_DIR, file_name)

    # NOTE(review): BASE_URL contains a single positional placeholder, so the
    # second argument (the version string) is silently ignored by str.format
    # here — confirm whether the URL was meant to include the bugbug version.
    base_model_url = BASE_URL.format(name, f"v{get_bugbug_version()}")
    model_url = f"{base_model_url}/{file_name}.zst"
    LOGGER.info(f"Checking ETAG of {model_url}")

    r = requests.head(model_url, allow_redirects=True)
    r.raise_for_status()
    new_etag = r.headers["ETag"]

    # A missing/unreadable ETag file simply forces a fresh download.
    try:
        with open(f"{file_path}.etag", "r") as f:
            old_etag = f.read()
    except IOError:
        old_etag = None

    if old_etag != new_etag:
        LOGGER.info(f"Downloading the model from {model_url}")
        urlretrieve(model_url, f"{file_path}.zst")

        zstd_decompress(file_path)
        LOGGER.info(f"Written model in {file_path}")

        # The ETag is recorded only after a successful download and
        # decompression, so a failed attempt is retried next time.
        with open(f"{file_path}.etag", "w") as f:
            f.write(new_etag)
    else:
        LOGGER.info(f"ETAG for {model_url} is ok")

    return file_path
def classify_bug(
    model_name, bug_ids, bugzilla_token, expiration=DEFAULT_EXPIRATION_TTL
):
    """Classify the given bugs with *model_name* and store results in Redis.

    Each result is stored as JSON under result_key(model_name, bug_id) with
    the given expiration TTL; the bug's last change time is stored under
    change_time_key(model_name, bug_id).  Bugs Bugzilla did not return are
    marked {"available": False}.  Returns "OK" on success, "NOK" when no bug
    could be fetched or the model is missing.
    """
    # This should be called in a process worker so it should be safe to set
    # the token here
    bug_ids_set = set(map(int, bug_ids))
    bugzilla.set_token(bugzilla_token)
    bugs = bugzilla.get(bug_ids)

    redis_url = os.environ.get("REDIS_URL", "redis://localhost/0")
    redis = Redis.from_url(redis_url)

    missing_bugs = bug_ids_set.difference(bugs.keys())

    for bug_id in missing_bugs:
        # Use the shared key helper instead of duplicating the key format
        # inline (this originally hard-coded f"result_{model_name}_{bug_id}",
        # which the loop below already builds via result_key).
        redis_key = result_key(model_name, bug_id)

        # TODO: Find a better error format
        encoded_data = json.dumps({"available": False})

        redis.set(redis_key, encoded_data)
        redis.expire(redis_key, expiration)

    if not bugs:
        return "NOK"

    model = get_model(model_name)

    if not model:
        # Log through the module logger for consistency (was a bare print).
        LOGGER.info("Missing model %r, aborting", model_name)
        return "NOK"

    model_extra_data = model.get_extra_data()

    # TODO: Classify could choke on a single bug which could make the whole
    # job to fails. What should we do here?
    probs = model.classify(list(bugs.values()), True)
    indexes = probs.argmax(axis=-1)
    suggestions = model.clf._le.inverse_transform(indexes)

    probs_list = probs.tolist()
    indexes_list = indexes.tolist()
    suggestions_list = suggestions.tolist()

    for i, bug_id in enumerate(bugs.keys()):
        data = {
            "prob": probs_list[i],
            "index": indexes_list[i],
            "class": suggestions_list[i],
            "extra_data": model_extra_data,
        }

        encoded_data = json.dumps(data)

        redis_key = result_key(model_name, bug_id)

        redis.set(redis_key, encoded_data)
        redis.expire(redis_key, expiration)

        # Save the bug last change
        change_key = change_time_key(model_name, bug_id)
        redis.set(change_key, bugs[bug_id]["last_change_time"])

    return "OK"