bugbug/http_service/models.py


# -*- coding: utf-8 -*-
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this file,
# You can obtain one at http://mozilla.org/MPL/2.0/.

import json
import logging
import os
from urllib.request import urlretrieve

import requests
from redis import Redis

from bugbug import bugzilla, get_bugbug_version
from bugbug.models import load_model
from bugbug.utils import zstd_decompress

logging.basicConfig(level=logging.INFO)
LOGGER = logging.getLogger()

MODELS_NAMES = [
    "defectenhancementtask",
    "component",
    "regression",
    "stepstoreproduce",
]
MODELS_DIR = os.path.join(os.path.dirname(__file__), "models")

# The second placeholder is filled with the bugbug version (see retrieve_model).
BASE_URL = "https://index.taskcluster.net/v1/task/project.relman.bugbug.train_{}.{}/artifacts/public"

DEFAULT_EXPIRATION_TTL = 7 * 24 * 3600  # A week

MODEL_CACHE = {}

ALLOW_MISSING_MODELS = bool(int(os.environ.get("BUGBUG_ALLOW_MISSING_MODELS", "0")))


def result_key(model_name, bug_id):
    return f"result_{model_name}_{bug_id}"


def change_time_key(model_name, bug_id):
    return f"bugbug:change_time_{model_name}_{bug_id}"


def get_model(model_name):
    if model_name not in MODEL_CACHE:
        LOGGER.info("Recreating the %r model in cache", model_name)
        try:
            model = load_model(model_name, MODELS_DIR)
        except FileNotFoundError:
            if ALLOW_MISSING_MODELS:
                LOGGER.info(
                    "Missing %r model, skipping because ALLOW_MISSING_MODELS is set",
                    model_name,
                )
                return None
            else:
                raise

        MODEL_CACHE[model_name] = model
        return model

    return MODEL_CACHE[model_name]


def preload_models():
    for model_name in MODELS_NAMES:
        get_model(model_name)


def retrieve_model(name):
    os.makedirs(MODELS_DIR, exist_ok=True)

    file_name = f"{name}model"
    file_path = os.path.join(MODELS_DIR, file_name)

    base_model_url = BASE_URL.format(name, f"v{get_bugbug_version()}")
    model_url = f"{base_model_url}/{file_name}.zst"
    LOGGER.info(f"Checking ETag of {model_url}")
    r = requests.head(model_url, allow_redirects=True)
    r.raise_for_status()
    new_etag = r.headers["ETag"]

    try:
        with open(f"{file_path}.etag", "r") as f:
            old_etag = f.read()
    except IOError:
        old_etag = None

    # Only download and decompress the model if its ETag changed since the
    # last download.
    if old_etag != new_etag:
        LOGGER.info(f"Downloading the model from {model_url}")
        urlretrieve(model_url, f"{file_path}.zst")
        zstd_decompress(file_path)
        LOGGER.info(f"Model written to {file_path}")

        with open(f"{file_path}.etag", "w") as f:
            f.write(new_etag)
    else:
        LOGGER.info(f"ETag for {model_url} is unchanged, skipping download")

    return file_path

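
# How the pieces above fit together, as a sketch (this call site is assumed
# for illustration and is not part of this module; the real entry point may
# wire things up differently): download fresh model artifacts, then warm the
# in-process cache.
#
#   for model_name in MODELS_NAMES:
#       retrieve_model(model_name)
#   preload_models()

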
def classify_bug(
    model_name, bug_ids, bugzilla_token, expiration=DEFAULT_EXPIRATION_TTL
):
    bug_ids_set = set(map(int, bug_ids))

    # This should be called in a process worker, so it should be safe to set
    # the token here.
    bugzilla.set_token(bugzilla_token)
    bugs = bugzilla.get(bug_ids)

    redis_url = os.environ.get("REDIS_URL", "redis://localhost/0")
    redis = Redis.from_url(redis_url)

    # Mark bugs that Bugzilla didn't return as unavailable, so clients don't
    # keep waiting for a result that will never arrive.
    missing_bugs = bug_ids_set.difference(bugs.keys())
    for bug_id in missing_bugs:
        redis_key = result_key(model_name, bug_id)

        # TODO: Find a better error format
        encoded_data = json.dumps({"available": False})

        redis.set(redis_key, encoded_data)
        redis.expire(redis_key, expiration)

    if not bugs:
        return "NOK"

    model = get_model(model_name)

    if not model:
        LOGGER.info("Missing model %r, aborting", model_name)
        return "NOK"

    model_extra_data = model.get_extra_data()

    # TODO: Classify could choke on a single bug, which would make the whole
    # job fail. What should we do here?
    # The second argument asks the model for per-class probabilities.
    probs = model.classify(list(bugs.values()), True)
    indexes = probs.argmax(axis=-1)
    suggestions = model.clf._le.inverse_transform(indexes)

    probs_list = probs.tolist()
    indexes_list = indexes.tolist()
    suggestions_list = suggestions.tolist()

    for i, bug_id in enumerate(bugs.keys()):
        data = {
            "prob": probs_list[i],
            "index": indexes_list[i],
            "class": suggestions_list[i],
            "extra_data": model_extra_data,
        }

        encoded_data = json.dumps(data)

        redis_key = result_key(model_name, bug_id)
        redis.set(redis_key, encoded_data)
        redis.expire(redis_key, expiration)

        # Save the bug's last change time, so consumers can tell when a
        # cached result has gone stale.
        change_key = change_time_key(model_name, bug_id)
        redis.set(change_key, bugs[bug_id]["last_change_time"])

    return "OK"