Add a central place where the models are defined (#398)

* Add a central place where the models are defined

Also add some helpers to load a model.

* Add missing tensorflow dependency in extra-nn-requirements.txt
This commit is contained in:
Boris Feld 2019-05-16 15:34:38 +02:00 коммит произвёл Marco
Родитель 78e6a8175a
Коммит 0a5e37439d
10 изменённых файлов: 104 добавлений и 129 удалений

Просмотреть файл

@ -90,6 +90,8 @@ tasks:
cd bugbug &&
git checkout ${head_rev} &&
pip install -r requirements.txt &&
pip install -r extra-nlp-requirements.txt &&
pip install -r extra-nn-requirements.txt &&
pip install -r test-requirements.txt &&
python -m pytest tests/test_*.py"
metadata:

Просмотреть файл

@ -1 +1,5 @@
# -*- coding: utf-8 -*-
import logging
logging.basicConfig(level=logging.INFO)

Просмотреть файл

@ -1 +1,57 @@
# -*- coding: utf-8 -*-
import importlib
import logging
import os
LOGGER = logging.getLogger()
MODELS = {
"assignee": "bugbug.models.assignee.AssigneeModel",
"backout": "bugbug.models.backout.BackoutModel",
"bug": "bugbug.model.BugModel",
"bugtype": "bugbug.models.bugtype.BugTypeModel",
"component": "bugbug.models.component.ComponentModel",
"component_nn": "bugbug.models.component_nn.ComponentNNModel",
"defect": "bugbug.models.defect.DefectModel",
"defectenhancementtask": "bugbug.models.defect_enhancement_task.DefectEnhancementTaskModel",
"devdocneeded": "bugbug.models.devdocneeded.DevDocNeededModel",
"qaneeded": "bugbug.models.qaneeded.QANeededModel",
"regression": "bugbug.models.regression.RegressionModel",
"stepstoreproduce": "bugbug.models.stepstoreproduce.StepsToReproduceModel",
"tracking": "bugbug.models.tracking.TrackingModel",
"uplift": "bugbug.models.uplift.UpliftModel",
}
def load_model_class(full_qualified_class_name):
""" Load the class dynamically in order to speed up the boot and allow for
dynamic optional dependencies to be declared and check at import time
"""
module_name, class_name = full_qualified_class_name.rsplit(".", 1)
module = importlib.import_module(module_name)
return getattr(module, class_name)
def get_model_class(model_name):
if model_name not in MODELS:
err_msg = f"Invalid name {model_name}, not in {list(MODELS.keys())}"
raise ValueError(err_msg)
full_qualified_class_name = MODELS[model_name]
return load_model_class(full_qualified_class_name)
def load_model(model_name, model_dir=None):
model_class = get_model_class(model_name)
if model_dir is None:
model_dir = "."
model_file_path = os.path.join(model_dir, f"{model_name}model")
LOGGER.info(f"Lookup model in {model_file_path}")
model = model_class.load(model_file_path)
return model

Просмотреть файл

@ -12,11 +12,9 @@ OPT_MSG_MISSING = (
)
try:
from keras.preprocessing.sequence import pad_sequences
from keras.preprocessing.text import Tokenizer
from keras.preprocessing.sequence import pad_sequences
from keras.utils import to_categorical
HAS_OPTIONAL_DEPENDENCIES = True
except ImportError:
raise ImportError(OPT_MSG_MISSING)

Просмотреть файл

@ -9,6 +9,7 @@ import os
import random
from bugbug import bugzilla
from bugbug.models import load_model
parser = argparse.ArgumentParser()
parser.add_argument(
@ -20,13 +21,9 @@ parser.add_argument(
args = parser.parse_args()
if args.goal == "str":
from bugbug.models.bug import BugModel
model = BugModel.load("bugmodel")
model = load_model("bug")
elif args.goal == "regressionrange":
from bugbug.models.regression import RegressionModel
model = RegressionModel.load("regressionmodel")
model = load_model("regression")
file_path = os.path.join("bugbug", "labels", f"{args.goal}.csv")

Просмотреть файл

@ -1 +1,2 @@
keras==2.2.4
tensorflow==1.13.1

56
run.py
Просмотреть файл

@ -12,6 +12,7 @@ import numpy as np
from bugbug import repository # noqa
from bugbug import bugzilla, db
from bugbug.models import get_model_class
if __name__ == "__main__":
parser = argparse.ArgumentParser()
@ -64,59 +65,18 @@ if __name__ == "__main__":
args.goal, "" if args.classifier == "default" else args.classifier
)
if args.goal == "defect":
from bugbug.models.defect import DefectModel
model_class_name = args.goal
model_class = DefectModel
elif args.goal == "defectenhancementtask":
from bugbug.models.defect_enhancement_task import DefectEnhancementTaskModel
model_class = DefectEnhancementTaskModel
elif args.goal == "regression":
from bugbug.models.regression import RegressionModel
model_class = RegressionModel
elif args.goal == "tracking":
from bugbug.models.tracking import TrackingModel
model_class = TrackingModel
elif args.goal == "qaneeded":
from bugbug.models.qaneeded import QANeededModel
model_class = QANeededModel
elif args.goal == "uplift":
from bugbug.models.uplift import UpliftModel
model_class = UpliftModel
elif args.goal == "component":
if args.goal == "component":
if args.classifier == "default":
from bugbug.models.component import ComponentModel
model_class = ComponentModel
model_class_name = "component"
elif args.classifier == "nn":
from bugbug.models.component_nn import ComponentNNModel
model_class_name = "component_nn"
else:
raise ValueError(f"Unkown value {args.classifier}")
model_class = ComponentNNModel
elif args.goal == "devdocneeded":
from bugbug.models.devdocneeded import DevDocNeededModel
model_class = get_model_class(model_class_name)
model_class = DevDocNeededModel
elif args.goal == "assignee":
from bugbug.models.assignee import AssigneeModel
model_class = AssigneeModel
elif args.goal == "backout":
from bugbug.models.backout import BackoutModel
model_class = BackoutModel
elif args.goal == "bugtype":
from bugbug.models.bugtype import BugTypeModel
model_class = BugTypeModel
elif args.goal == "stepstoreproduce":
from bugbug.models.stepstoreproduce import StepsToReproduceModel
model_class = StepsToReproduceModel
if args.train:
db.download()

Просмотреть файл

@ -4,7 +4,7 @@ import argparse
import sys
from logging import INFO, basicConfig, getLogger
from bugbug.models.component import ComponentModel
from bugbug.models import load_model
basicConfig(level=INFO)
logger = getLogger(__name__)
@ -12,30 +12,14 @@ logger = getLogger(__name__)
class ModelChecker:
def go(self, model_name):
# TODO: Stop hard-coding them
valid_models = ["component"]
if model_name not in valid_models:
exception = f"Invalid model {model_name!r} name, use one of {valid_models!r} instead"
raise ValueError(exception)
# TODO: What is the standard file path of the models?
model_file_name = f"{model_name}model"
if model_name == "component":
model_class = ComponentModel
else:
# We shouldn't be here
raise Exception("valid_models is likely not up-to-date anymore")
# Load the model
model = model_class.load(model_file_name)
model = load_model(model_name)
# Then call the check method of the model
success = model.check()
if not success:
msg = f"Check of model {model_class!r} failed, check the output for reasons why"
msg = f"Check of model {model.__class__!r} failed, check the output for reasons why"
logger.warning(msg)
sys.exit(1)

Просмотреть файл

@ -7,10 +7,7 @@ import shutil
from logging import INFO, basicConfig, getLogger
from urllib.request import urlretrieve
from bugbug.models.component import ComponentModel
from bugbug.models.defect_enhancement_task import DefectEnhancementTaskModel
from bugbug.models.regression import RegressionModel
from bugbug.models.tracking import TrackingModel
from bugbug.models import get_model_class
basicConfig(level=INFO)
logger = getLogger(__name__)
@ -29,40 +26,7 @@ class Trainer(object):
with lzma.open(f"{path}.xz", "wb") as output_f:
shutil.copyfileobj(input_f, output_f)
def train_defect_enhancement_task(self):
logger.info("Training *defect vs enhancement vs task* model")
model = DefectEnhancementTaskModel()
model.train()
self.compress_file("defectenhancementtaskmodel")
def train_component(self):
logger.info("Training *component* model")
model = ComponentModel()
model.train()
self.compress_file("componentmodel")
def train_regression(self):
logger.info("Training *regression vs non-regression* model")
model = RegressionModel()
model.train()
self.compress_file("regressionmodel")
def train_tracking(self):
logger.info("Training *tracking* model")
model = TrackingModel()
model.train()
self.compress_file("trackingmodel")
def go(self, model):
# TODO: Stop hard-coding them
valid_models = ["defect", "component", "regression", "tracking"]
if model not in valid_models:
exception = (
f"Invalid model {model!r} name, use one of {valid_models!r} instead"
)
raise ValueError(exception)
def go(self, model_name):
# Download datasets that were built by bugbug_data.
os.makedirs("data", exist_ok=True)
@ -73,21 +37,14 @@ class Trainer(object):
logger.info("Decompressing bugs database")
self.decompress_file("data/bugs.json")
if model == "defect":
# Train classifier for defect-vs-enhancement-vs-task.
self.train_defect_enhancement_task()
elif model == "component":
# Train classifier for the component of a bug.
self.train_component()
elif model == "regression":
# Train classifier for regression-vs-nonregression.
self.train_regression()
elif model == "tracking":
# Train classifier for tracking bugs.
self.train_tracking()
else:
# We shouldn't be here
raise Exception("valid_models is likely not up-to-date anymore")
logger.info(f"Training *{model_name}* model")
model_class = get_model_class(model_name)
model = model_class()
model.train()
model_file_name = f"{model_name}model"
self.compress_file(model_file_name)
def main():

16
tests/test_models.py Normal file
Просмотреть файл

@ -0,0 +1,16 @@
# -*- coding: utf-8 -*-
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this file,
# You can obtain one at http://mozilla.org/MPL/2.0/.
from bugbug.models import MODELS, get_model_class
def test_import_all_models():
""" Try loading all defined models to ensure that their full qualified
names are still good
"""
for model_name in MODELS.keys():
print("Try loading model", model_name)
get_model_class(model_name)