зеркало из https://github.com/mozilla/bugbug.git
Import Trainer class from release-services repository (#254)
* Import Trainer class from release-services repository This basically import the `trainer.py` file from the `release-services` repository at hash 77cdddd. I removed imports and reference to cli-common helpers that will likely need to be reimplemented, like the raven support. Also defines 4 docker images, one per model to train. * Remove unused imports
This commit is contained in:
Родитель
b651744b18
Коммит
6af6e8b927
|
@ -31,3 +31,27 @@ services:
|
|||
target: /cache/
|
||||
volume:
|
||||
nocopy: true
|
||||
|
||||
bugbug-train-component:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: infra/dockerfile.train_component
|
||||
image: mozilla/bugbug-train-component
|
||||
|
||||
bugbug-train-defect:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: infra/dockerfile.train_defect
|
||||
image: mozilla/bugbug-train-defect
|
||||
|
||||
bugbug-train-regression:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: infra/dockerfile.train_regression
|
||||
image: mozilla/bugbug-train-regression
|
||||
|
||||
bugbug-train-tracking:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: infra/dockerfile.train_tracking
|
||||
image: mozilla/bugbug-train-tracking
|
||||
|
|
|
@ -0,0 +1,107 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import argparse
|
||||
import lzma
|
||||
import os
|
||||
import shutil
|
||||
from logging import INFO, basicConfig, getLogger
|
||||
from urllib.request import urlretrieve
|
||||
|
||||
from bugbug.models.component import ComponentModel
|
||||
from bugbug.models.defect_enhancement_task import DefectEnhancementTaskModel
|
||||
from bugbug.models.regression import RegressionModel
|
||||
from bugbug.models.tracking import TrackingModel
|
||||
|
||||
basicConfig(level=INFO)
|
||||
logger = getLogger(__name__)
|
||||
|
||||
BASE_URL = "https://index.taskcluster.net/v1/task/project.releng.services.project.testing.bugbug_data.latest/artifacts/public"
|
||||
|
||||
|
||||
class Trainer(object):
|
||||
def decompress_file(self, path):
|
||||
with lzma.open(f"{path}.xz", "rb") as input_f:
|
||||
with open(path, "wb") as output_f:
|
||||
shutil.copyfileobj(input_f, output_f)
|
||||
|
||||
def compress_file(self, path):
|
||||
with open(path, "rb") as input_f:
|
||||
with lzma.open(f"{path}.xz", "wb") as output_f:
|
||||
shutil.copyfileobj(input_f, output_f)
|
||||
|
||||
def train_defect_enhancement_task(self):
|
||||
logger.info("Training *defect vs enhancement vs task* model")
|
||||
model = DefectEnhancementTaskModel()
|
||||
model.train()
|
||||
self.compress_file("defectenhancementtaskmodel")
|
||||
|
||||
def train_component(self):
|
||||
logger.info("Training *component* model")
|
||||
model = ComponentModel()
|
||||
model.train()
|
||||
self.compress_file("componentmodel")
|
||||
|
||||
def train_regression(self):
|
||||
logger.info("Training *regression vs non-regression* model")
|
||||
model = RegressionModel()
|
||||
model.train()
|
||||
self.compress_file("regressionmodel")
|
||||
|
||||
def train_tracking(self):
|
||||
logger.info("Training *tracking* model")
|
||||
model = TrackingModel()
|
||||
model.train()
|
||||
self.compress_file("trackingmodel")
|
||||
|
||||
def go(self, model):
|
||||
# TODO: Stop hard-coding them
|
||||
valid_models = ["defect", "component", "regression", "tracking"]
|
||||
|
||||
if model not in valid_models:
|
||||
exception = (
|
||||
f"Invalid model {model!r} name, use one of {valid_models!r} instead"
|
||||
)
|
||||
raise ValueError(exception)
|
||||
|
||||
# Download datasets that were built by bugbug_data.
|
||||
os.makedirs("data", exist_ok=True)
|
||||
|
||||
# Bugs.json
|
||||
logger.info("Downloading bugs database")
|
||||
urlretrieve(f"{BASE_URL}/bugs.json.xz", "data/bugs.json.xz")
|
||||
logger.info("Decompressing bugs database")
|
||||
self.decompress_file("data/bugs.json")
|
||||
|
||||
# Commits.json
|
||||
logger.info("Downloading commits database")
|
||||
urlretrieve(f"{BASE_URL}/commits.json.xz", "data/commits.json.xz")
|
||||
logger.info("Decompressing commits database")
|
||||
self.decompress_file("data/commits.json")
|
||||
|
||||
if model == "defect":
|
||||
# Train classifier for defect-vs-enhancement-vs-task.
|
||||
self.train_defect_enhancement_task()
|
||||
elif model == "component":
|
||||
# Train classifier for the component of a bug.
|
||||
self.train_component()
|
||||
elif model == "regression":
|
||||
# Train classifier for regression-vs-nonregression.
|
||||
self.train_regression()
|
||||
elif model == "tracking":
|
||||
# Train classifier for tracking bugs.
|
||||
self.train_tracking()
|
||||
else:
|
||||
# We shouldn't be here
|
||||
raise Exception("valid_models is likely not up-to-date anymore")
|
||||
|
||||
|
||||
def main():
|
||||
description = "Train the models"
|
||||
parser = argparse.ArgumentParser(description=description)
|
||||
|
||||
parser.add_argument("model", help="Which model to train.")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
retriever = Trainer()
|
||||
retriever.go(args.model)
|
1
setup.py
1
setup.py
|
@ -52,6 +52,7 @@ setup(
|
|||
"console_scripts": [
|
||||
"bugbug-data-commits = scripts.commit_retriever:main",
|
||||
"bugbug-data-bugzilla = scripts.bug_retriever:main",
|
||||
"bugbug-train = scripts.trainer:main",
|
||||
]
|
||||
},
|
||||
)
|
||||
|
|
Загрузка…
Ссылка в новой задаче