bugbug/http_service/app.py

469 строки
14 KiB
Python

# -*- coding: utf-8 -*-
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this file,
# You can obtain one at http://mozilla.org/MPL/2.0/.
import json
import logging
import os
import uuid
from datetime import datetime, timedelta
from apispec import APISpec
from apispec.ext.marshmallow import MarshmallowPlugin
from apispec_webframeworks.flask import FlaskPlugin
from cerberus import Validator
from flask import Flask, jsonify, render_template, request
from flask_cors import cross_origin
from marshmallow import Schema, fields
from redis import Redis
from rq import Queue
from rq.exceptions import NoSuchJobError
from rq.job import Job
from bugbug import get_bugbug_version
from .models import MODELS_NAMES, change_time_key, classify_bug, result_key
from .utils import get_bugzilla_http_client
API_TOKEN = "X-Api-Key"
API_DESCRIPTION = """
This is the documentation for the BubBug http service, the platform for Bugzilla Machine Learning projects.
# Introduction
This service can be used to classify a given bug using a pre-trained model.
You can classify a single bug or a batch of bugs.
The classification happens in the background so you need to call back the service for getting the results.
"""
spec = APISpec(
title="Bugbug",
version=get_bugbug_version(),
openapi_version="3.0.2",
info=dict(description=API_DESCRIPTION),
plugins=[FlaskPlugin(), MarshmallowPlugin()],
security=[{"api_key": []}],
)
application = Flask(__name__)
redis_url = os.environ.get("REDIS_URL", "redis://localhost/0")
redis_conn = Redis.from_url(redis_url)
JOB_TIMEOUT = 1800 # 30 minutes in seconds
q = Queue(
connection=redis_conn, default_timeout=JOB_TIMEOUT
) # no args implies the default queue
VALIDATOR = Validator()
BUGZILLA_TOKEN = os.environ.get("BUGBUG_BUGZILLA_TOKEN")
# Keep an HTTP client around for persistent connections
BUGBUG_HTTP_CLIENT, BUGZILLA_API_URL = get_bugzilla_http_client()
logging.basicConfig(level=logging.DEBUG)
LOGGER = logging.getLogger()
class BugPrediction(Schema):
prob = fields.List(fields.Float())
index = fields.Integer()
suggestion = fields.Str()
extra_data = fields.Dict()
class BugPredictionNotAvailableYet(Schema):
ready = fields.Boolean(enum=[False])
class ModelName(Schema):
model_name = fields.Str(enum=MODELS_NAMES, example="component")
class UnauthorizedError(Schema):
message = fields.Str(default="Error, missing X-API-KEY")
spec.components.schema(BugPrediction.__name__, schema=BugPrediction)
spec.components.schema(
BugPredictionNotAvailableYet.__name__, schema=BugPredictionNotAvailableYet
)
spec.components.schema(ModelName.__name__, schema=ModelName)
spec.components.schema(UnauthorizedError.__name__, schema=UnauthorizedError)
api_key_scheme = {"type": "apiKey", "in": "header", "name": "X-API-Key"}
spec.components.security_scheme("api_key", api_key_scheme)
def get_job_id():
return uuid.uuid4().hex
def get_mapping_key(model_name, bug_id):
return f"bugbug:mapping_{model_name}_{bug_id}"
def schedule_bug_classification(model_name, bug_ids):
""" Schedule the classification of a bug_id list
"""
job_id = get_job_id()
# Set the mapping before queuing to avoid some race conditions
job_id_mapping = {get_mapping_key(model_name, bug_id): job_id for bug_id in bug_ids}
redis_conn.mset(job_id_mapping)
q.enqueue(classify_bug, model_name, bug_ids, BUGZILLA_TOKEN, job_id=job_id)
def is_running(model_name, bug_id):
# Check if there is a job
mapping_key = get_mapping_key(model_name, bug_id)
job_id = redis_conn.get(mapping_key)
if not job_id:
LOGGER.debug("No job ID mapping %s, False", job_id)
return False
try:
job = Job.fetch(job_id.decode("utf-8"), connection=redis_conn)
except NoSuchJobError:
LOGGER.debug("No job in DB for %s, False", job_id)
# The job might have expired from redis
return False
job_status = job.get_status()
if job_status == "started":
LOGGER.debug("Job running %s, True", job_id)
return True
# Enforce job timeout as RQ doesn't seems to do it https://github.com/rq/rq/issues/758
timeout_datetime = job.enqueued_at + timedelta(seconds=job.timeout)
utcnow = datetime.utcnow()
if timeout_datetime < utcnow:
# Remove the timeouted job so it will be requeued
job.cancel()
job.cleanup()
LOGGER.debug("Job timeout %s, False", job_id)
return False
LOGGER.debug("Job status %s, False", job_status)
return False
def get_bugs_last_change_time(bug_ids):
query = {
"id": ",".join(map(str, bug_ids)),
"include_fields": ["last_change_time", "id"],
}
header = {"X-Bugzilla-API-Key": "", "User-Agent": "bugbug"}
response = BUGBUG_HTTP_CLIENT.get(
BUGZILLA_API_URL, params=query, headers=header, verify=True, timeout=30
)
response.raise_for_status()
raw_bugs = response.json()
bugs = {}
for bug in raw_bugs["bugs"]:
bugs[bug["id"]] = bug["last_change_time"]
return bugs
def is_prediction_invalidated(model_name, bug_id, change_time):
# First get the saved change time
change_key = change_time_key(model_name, bug_id)
saved_change_time = redis_conn.get(change_key)
# If we have no last changed time, the bug was not classified yet or the bug was classified by an old worker
if not saved_change_time:
# We can have a result without a cache time
if redis_conn.exists(result_key(model_name, bug_id)):
return True
return False
return saved_change_time.decode("utf-8") != change_time
def clean_prediction_cache(model_name, bug_id):
# If the bug was modified since last time we classified it, clear the cache to avoid stale answer
LOGGER.debug("Cleaning results for bug id %s and model %s", bug_id, model_name)
redis_conn.delete(result_key(model_name, bug_id))
redis_conn.delete(change_time_key(model_name, bug_id))
def get_bug_classification(model_name, bug_id):
redis_key = f"result_{model_name}_{bug_id}"
result = redis_conn.get(redis_key)
if result:
return json.loads(result)
return None
@application.route("/<model_name>/predict/<int:bug_id>")
@cross_origin()
def model_prediction(model_name, bug_id):
"""
---
get:
description: Classify a single bug using given model, answer either 200 if the bug is processed or 202 if the bug is being processed
summary: Classify a single bug
parameters:
- name: model_name
in: path
schema: ModelName
- name: bug_id
in: path
schema:
type: integer
example: 123456
responses:
200:
description: A single bug prediction
content:
application/json:
schema: BugPrediction
202:
description: A temporary answer for the bug being processed
content:
application/json:
schema:
type: object
properties:
ready:
type: boolean
enum: [False]
401:
description: API key is missing
content:
application/json:
schema: UnauthorizedError
"""
headers = request.headers
redis_conn.ping()
auth = headers.get(API_TOKEN)
if not auth:
return jsonify(UnauthorizedError().dump({}).data), 401
else:
LOGGER.info("Request with API TOKEN %r", auth)
# Get the latest change from Bugzilla for the bug
bug = get_bugs_last_change_time([bug_id], auth)
if is_prediction_invalidated(model_name, bug_id, bug[bug_id]):
clean_prediction_cache(model_name, bug_id)
status_code = 200
data = get_bug_classification(model_name, bug_id)
if not data:
if not is_running(model_name, bug_id):
schedule_bug_classification(model_name, [bug_id])
status_code = 202
data = {"ready": False}
return jsonify(**data), status_code
@application.route("/<model_name>/predict/batch", methods=["POST"])
@cross_origin()
def batch_prediction(model_name):
"""
---
post:
description: >
Post a batch of bug ids to classify, answer either 200 if all bugs are
processed or 202 if at least one bug is not processed.
<br/><br/>
Starts by sending a batch of bugs ids like this:<br/>
```
{"bugs": [123, 456]}
```<br/><br>
You will likely get a 202 answer that indicates that no result is
available yet for any of the bug id you provided with the following
body:<br/>
```
{"bugs": {"123": {ready: False}, "456": {ready: False}}}
```<br/><br/>
Call back the same endpoint with the same bug ids a bit later, and you
will get the results.<br/><br/>
You might get the following output if some bugs are not available:
<br/>
```
{"bugs": {"123": {"available": False}}}
```<br/><br/>
And you will get the following output once the bugs are available:
<br/>
```
{"bugs": {"456": {"extra_data": {}, "index": 0, "prob": [0], "suggestion": ""}}}
```<br/><br/>
Please be aware that each bug could be in a different state, so the
following output, where a bug is returned and another one is still
being processed, is valid:
<br/>
```
{"bugs": {"123": {"available": False}, "456": {"extra_data": {}, "index": 0, "prob": [0], "suggestion": ""}}}
```
summary: Classify a batch of bugs
parameters:
- name: model_name
in: path
schema: ModelName
requestBody:
description: The list of bugs to classify
content:
application/json:
schema:
type: object
properties:
bugs:
type: array
items:
type: integer
examples:
cat:
summary: An example of payload
value:
bugs:
[123456, 789012]
responses:
200:
description: A list of results
content:
application/json:
schema:
type: object
additionalProperties: true
example:
bugs:
123456:
extra_data: {}
index: 0
prob: [0]
suggestion: string
789012:
extra_data: {}
index: 0
prob: [0]
suggestion: string
202:
description: A temporary answer for bugs being processed
content:
application/json:
schema:
type: object
items:
type: object
properties:
ready:
type: boolean
enum: [False]
example:
bugs:
123456:
extra_data: {}
index: 0
prob: [0]
suggestion: string
789012: {ready: False}
401:
description: API key is missing
content:
application/json:
schema: UnauthorizedError
"""
headers = request.headers
auth = headers.get(API_TOKEN)
if not auth:
return jsonify(UnauthorizedError().dump({}).data), 401
else:
LOGGER.info("Request with API TOKEN %r", auth)
# TODO Check is JSON is valid and validate against a request schema
batch_body = json.loads(request.data)
# Validate
schema = {
"bugs": {
"type": "list",
"minlength": 1,
"maxlength": 1000,
"schema": {"type": "integer"},
}
}
validator = Validator()
if not validator.validate(batch_body, schema):
return jsonify({"errors": validator.errors}), 400
bugs = batch_body["bugs"]
status_code = 200
data = {}
missing_bugs = []
bug_change_dates = get_bugs_last_change_time(bugs)
for bug_id in bugs:
change_time = bug_change_dates.get(int(bug_id), None)
# Change time could be None if it's a security bug
if change_time and is_prediction_invalidated(model_name, bug_id, change_time):
clean_prediction_cache(model_name, bug_id)
data[str(bug_id)] = get_bug_classification(model_name, bug_id)
if not data[str(bug_id)]:
if not is_running(model_name, bug_id):
missing_bugs.append(bug_id)
status_code = 202
data[str(bug_id)] = {"ready": False}
if missing_bugs:
# TODO: We should probably schedule chunks of bugs to avoid jobs that
# are running for too long and reduce pressure on bugzilla, it mights
# not like getting 1 million bug at a time
schedule_bug_classification(model_name, missing_bugs)
return jsonify({"bugs": data}), status_code
@application.route("/swagger")
@cross_origin()
def swagger():
for name, rule in application.view_functions.items():
# Ignore static endpoint as it isn't documented with OpenAPI
if name == "static":
continue
spec.path(view=rule)
return jsonify(spec.to_dict())
@application.route("/doc")
def doc():
return render_template("doc.html")