зеркало из https://github.com/mozilla/bugbug.git
Add a central place where the models are defined (#398)
* Add a central place where the models are defined. Also add some helpers to load a model. * Add missing tensorflow dependency in extra-nn-requirements.txt
This commit is contained in:
Родитель
78e6a8175a
Коммит
0a5e37439d
|
@ -90,6 +90,8 @@ tasks:
|
|||
cd bugbug &&
|
||||
git checkout ${head_rev} &&
|
||||
pip install -r requirements.txt &&
|
||||
pip install -r extra-nlp-requirements.txt &&
|
||||
pip install -r extra-nn-requirements.txt &&
|
||||
pip install -r test-requirements.txt &&
|
||||
python -m pytest tests/test_*.py"
|
||||
metadata:
|
||||
|
|
|
@ -1 +1,5 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import logging
|
||||
|
||||
# Configure the root logger at INFO level; this runs as a side effect of
# importing the module.
logging.basicConfig(level=logging.INFO)
|
||||
|
|
|
@ -1 +1,57 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
|
||||
# Module-wide logger; getLogger() called without a name returns the root logger.
LOGGER = logging.getLogger()


# Registry mapping each model name to the fully qualified dotted path of its
# class. Paths are stored as strings so that a model's module is only imported
# when that model is actually requested (see load_model_class).
MODELS = {
    "assignee": "bugbug.models.assignee.AssigneeModel",
    "backout": "bugbug.models.backout.BackoutModel",
    # NOTE(review): unlike every other entry, this path is under `bugbug.model`
    # rather than `bugbug.models.<module>` — confirm BugModel really lives there.
    "bug": "bugbug.model.BugModel",
    "bugtype": "bugbug.models.bugtype.BugTypeModel",
    "component": "bugbug.models.component.ComponentModel",
    "component_nn": "bugbug.models.component_nn.ComponentNNModel",
    "defect": "bugbug.models.defect.DefectModel",
    "defectenhancementtask": "bugbug.models.defect_enhancement_task.DefectEnhancementTaskModel",
    "devdocneeded": "bugbug.models.devdocneeded.DevDocNeededModel",
    "qaneeded": "bugbug.models.qaneeded.QANeededModel",
    "regression": "bugbug.models.regression.RegressionModel",
    "stepstoreproduce": "bugbug.models.stepstoreproduce.StepsToReproduceModel",
    "tracking": "bugbug.models.tracking.TrackingModel",
    "uplift": "bugbug.models.uplift.UpliftModel",
}
|
||||
|
||||
|
||||
def load_model_class(full_qualified_class_name):
    """Import and return the class named by a fully qualified dotted path.

    The class is loaded dynamically in order to speed up the boot and to allow
    optional dependencies to be declared and checked at import time, i.e. only
    when the module that needs them is actually imported.

    Args:
        full_qualified_class_name: Dotted path such as
            ``"package.module.ClassName"``.

    Returns:
        The attribute named ``ClassName`` of the imported module.

    Raises:
        ValueError: If the path has no module part (no dot, or nothing
            before the last dot).
        ImportError: If the module cannot be imported.
        AttributeError: If the module does not define the requested name.
    """
    module_name, separator, class_name = full_qualified_class_name.rpartition(".")

    # Without a module part there is nothing to import; fail with a message
    # that names the offending path instead of an opaque unpacking error.
    if not separator or not module_name:
        raise ValueError(
            "Expected a dotted path like 'package.module.ClassName', "
            f"got {full_qualified_class_name!r}"
        )

    module = importlib.import_module(module_name)

    return getattr(module, class_name)
|
||||
|
||||
|
||||
def get_model_class(model_name):
    """Return the model class registered in MODELS under ``model_name``.

    Raises:
        ValueError: If ``model_name`` is not a key of MODELS.
    """
    if model_name in MODELS:
        # Known name: resolve its dotted path into the actual class.
        return load_model_class(MODELS[model_name])

    raise ValueError(f"Invalid name {model_name}, not in {list(MODELS.keys())}")
|
||||
|
||||
|
||||
def load_model(model_name, model_dir=None):
    """Load the saved model registered under ``model_name``.

    The model file is named ``<model_name>model`` and is looked up in
    ``model_dir``, or in the current directory when ``model_dir`` is ``None``.

    Returns:
        Whatever the model class's ``load`` returns for that file path.
    """
    # Resolve the class first so an unknown name fails before any path work.
    cls = get_model_class(model_name)

    directory = "." if model_dir is None else model_dir
    model_file_path = os.path.join(directory, f"{model_name}model")
    LOGGER.info(f"Lookup model in {model_file_path}")

    return cls.load(model_file_path)
|
||||
|
|
|
@ -12,11 +12,9 @@ OPT_MSG_MISSING = (
|
|||
)
|
||||
|
||||
try:
|
||||
from keras.preprocessing.sequence import pad_sequences
|
||||
from keras.preprocessing.text import Tokenizer
|
||||
from keras.preprocessing.sequence import pad_sequences
|
||||
from keras.utils import to_categorical
|
||||
|
||||
HAS_OPTIONAL_DEPENDENCIES = True
|
||||
except ImportError:
|
||||
raise ImportError(OPT_MSG_MISSING)
|
||||
|
||||
|
|
|
@ -9,6 +9,7 @@ import os
|
|||
import random
|
||||
|
||||
from bugbug import bugzilla
|
||||
from bugbug.models import load_model
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
|
@ -20,13 +21,9 @@ parser.add_argument(
|
|||
args = parser.parse_args()
|
||||
|
||||
if args.goal == "str":
|
||||
from bugbug.models.bug import BugModel
|
||||
|
||||
model = BugModel.load("bugmodel")
|
||||
model = load_model("bug")
|
||||
elif args.goal == "regressionrange":
|
||||
from bugbug.models.regression import RegressionModel
|
||||
|
||||
model = RegressionModel.load("regressionmodel")
|
||||
model = load_model("regression")
|
||||
|
||||
file_path = os.path.join("bugbug", "labels", f"{args.goal}.csv")
|
||||
|
||||
|
|
|
@ -1 +1,2 @@
|
|||
keras==2.2.4
|
||||
tensorflow==1.13.1
|
56
run.py
56
run.py
|
@ -12,6 +12,7 @@ import numpy as np
|
|||
|
||||
from bugbug import repository # noqa
|
||||
from bugbug import bugzilla, db
|
||||
from bugbug.models import get_model_class
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
@ -64,59 +65,18 @@ if __name__ == "__main__":
|
|||
args.goal, "" if args.classifier == "default" else args.classifier
|
||||
)
|
||||
|
||||
if args.goal == "defect":
|
||||
from bugbug.models.defect import DefectModel
|
||||
model_class_name = args.goal
|
||||
|
||||
model_class = DefectModel
|
||||
elif args.goal == "defectenhancementtask":
|
||||
from bugbug.models.defect_enhancement_task import DefectEnhancementTaskModel
|
||||
|
||||
model_class = DefectEnhancementTaskModel
|
||||
elif args.goal == "regression":
|
||||
from bugbug.models.regression import RegressionModel
|
||||
|
||||
model_class = RegressionModel
|
||||
elif args.goal == "tracking":
|
||||
from bugbug.models.tracking import TrackingModel
|
||||
|
||||
model_class = TrackingModel
|
||||
elif args.goal == "qaneeded":
|
||||
from bugbug.models.qaneeded import QANeededModel
|
||||
|
||||
model_class = QANeededModel
|
||||
elif args.goal == "uplift":
|
||||
from bugbug.models.uplift import UpliftModel
|
||||
|
||||
model_class = UpliftModel
|
||||
elif args.goal == "component":
|
||||
if args.goal == "component":
|
||||
if args.classifier == "default":
|
||||
from bugbug.models.component import ComponentModel
|
||||
|
||||
model_class = ComponentModel
|
||||
model_class_name = "component"
|
||||
elif args.classifier == "nn":
|
||||
from bugbug.models.component_nn import ComponentNNModel
|
||||
model_class_name = "component_nn"
|
||||
else:
|
||||
raise ValueError(f"Unkown value {args.classifier}")
|
||||
|
||||
model_class = ComponentNNModel
|
||||
elif args.goal == "devdocneeded":
|
||||
from bugbug.models.devdocneeded import DevDocNeededModel
|
||||
model_class = get_model_class(model_class_name)
|
||||
|
||||
model_class = DevDocNeededModel
|
||||
elif args.goal == "assignee":
|
||||
from bugbug.models.assignee import AssigneeModel
|
||||
|
||||
model_class = AssigneeModel
|
||||
elif args.goal == "backout":
|
||||
from bugbug.models.backout import BackoutModel
|
||||
|
||||
model_class = BackoutModel
|
||||
elif args.goal == "bugtype":
|
||||
from bugbug.models.bugtype import BugTypeModel
|
||||
|
||||
model_class = BugTypeModel
|
||||
elif args.goal == "stepstoreproduce":
|
||||
from bugbug.models.stepstoreproduce import StepsToReproduceModel
|
||||
|
||||
model_class = StepsToReproduceModel
|
||||
if args.train:
|
||||
db.download()
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@ import argparse
|
|||
import sys
|
||||
from logging import INFO, basicConfig, getLogger
|
||||
|
||||
from bugbug.models.component import ComponentModel
|
||||
from bugbug.models import load_model
|
||||
|
||||
basicConfig(level=INFO)
|
||||
logger = getLogger(__name__)
|
||||
|
@ -12,30 +12,14 @@ logger = getLogger(__name__)
|
|||
|
||||
class ModelChecker:
|
||||
def go(self, model_name):
|
||||
# TODO: Stop hard-coding them
|
||||
valid_models = ["component"]
|
||||
|
||||
if model_name not in valid_models:
|
||||
exception = f"Invalid model {model_name!r} name, use one of {valid_models!r} instead"
|
||||
raise ValueError(exception)
|
||||
|
||||
# TODO: What is the standard file path of the models?
|
||||
model_file_name = f"{model_name}model"
|
||||
|
||||
if model_name == "component":
|
||||
model_class = ComponentModel
|
||||
else:
|
||||
# We shouldn't be here
|
||||
raise Exception("valid_models is likely not up-to-date anymore")
|
||||
|
||||
# Load the model
|
||||
model = model_class.load(model_file_name)
|
||||
model = load_model(model_name)
|
||||
|
||||
# Then call the check method of the model
|
||||
success = model.check()
|
||||
|
||||
if not success:
|
||||
msg = f"Check of model {model_class!r} failed, check the output for reasons why"
|
||||
msg = f"Check of model {model.__class__!r} failed, check the output for reasons why"
|
||||
logger.warning(msg)
|
||||
sys.exit(1)
|
||||
|
||||
|
|
|
@ -7,10 +7,7 @@ import shutil
|
|||
from logging import INFO, basicConfig, getLogger
|
||||
from urllib.request import urlretrieve
|
||||
|
||||
from bugbug.models.component import ComponentModel
|
||||
from bugbug.models.defect_enhancement_task import DefectEnhancementTaskModel
|
||||
from bugbug.models.regression import RegressionModel
|
||||
from bugbug.models.tracking import TrackingModel
|
||||
from bugbug.models import get_model_class
|
||||
|
||||
basicConfig(level=INFO)
|
||||
logger = getLogger(__name__)
|
||||
|
@ -29,40 +26,7 @@ class Trainer(object):
|
|||
with lzma.open(f"{path}.xz", "wb") as output_f:
|
||||
shutil.copyfileobj(input_f, output_f)
|
||||
|
||||
def train_defect_enhancement_task(self):
|
||||
logger.info("Training *defect vs enhancement vs task* model")
|
||||
model = DefectEnhancementTaskModel()
|
||||
model.train()
|
||||
self.compress_file("defectenhancementtaskmodel")
|
||||
|
||||
def train_component(self):
|
||||
logger.info("Training *component* model")
|
||||
model = ComponentModel()
|
||||
model.train()
|
||||
self.compress_file("componentmodel")
|
||||
|
||||
def train_regression(self):
|
||||
logger.info("Training *regression vs non-regression* model")
|
||||
model = RegressionModel()
|
||||
model.train()
|
||||
self.compress_file("regressionmodel")
|
||||
|
||||
def train_tracking(self):
|
||||
logger.info("Training *tracking* model")
|
||||
model = TrackingModel()
|
||||
model.train()
|
||||
self.compress_file("trackingmodel")
|
||||
|
||||
def go(self, model):
|
||||
# TODO: Stop hard-coding them
|
||||
valid_models = ["defect", "component", "regression", "tracking"]
|
||||
|
||||
if model not in valid_models:
|
||||
exception = (
|
||||
f"Invalid model {model!r} name, use one of {valid_models!r} instead"
|
||||
)
|
||||
raise ValueError(exception)
|
||||
|
||||
def go(self, model_name):
|
||||
# Download datasets that were built by bugbug_data.
|
||||
os.makedirs("data", exist_ok=True)
|
||||
|
||||
|
@ -73,21 +37,14 @@ class Trainer(object):
|
|||
logger.info("Decompressing bugs database")
|
||||
self.decompress_file("data/bugs.json")
|
||||
|
||||
if model == "defect":
|
||||
# Train classifier for defect-vs-enhancement-vs-task.
|
||||
self.train_defect_enhancement_task()
|
||||
elif model == "component":
|
||||
# Train classifier for the component of a bug.
|
||||
self.train_component()
|
||||
elif model == "regression":
|
||||
# Train classifier for regression-vs-nonregression.
|
||||
self.train_regression()
|
||||
elif model == "tracking":
|
||||
# Train classifier for tracking bugs.
|
||||
self.train_tracking()
|
||||
else:
|
||||
# We shouldn't be here
|
||||
raise Exception("valid_models is likely not up-to-date anymore")
|
||||
logger.info(f"Training *{model_name}* model")
|
||||
|
||||
model_class = get_model_class(model_name)
|
||||
model = model_class()
|
||||
model.train()
|
||||
|
||||
model_file_name = f"{model_name}model"
|
||||
self.compress_file(model_file_name)
|
||||
|
||||
|
||||
def main():
|
||||
|
|
|
@ -0,0 +1,16 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# This Source Code Form is subject to the terms of the Mozilla Public
|
||||
# License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
# You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
from bugbug.models import MODELS, get_model_class
|
||||
|
||||
|
||||
def test_import_all_models():
    """Resolve every registered model to check that each dotted path in
    MODELS still points at an importable class.
    """
    for name in MODELS:
        print("Try loading model", name)
        get_model_class(name)
|
Загрузка…
Ссылка в новой задаче