Import Trainer class from release-services repository (#254)

* Import Trainer class from release-services repository

This basically imports the `trainer.py` file from the `release-services`
repository at hash 77cdddd. I removed imports of and references to cli-common
helpers that will likely need to be reimplemented, like the raven support.

Also defines 4 docker images, one per model to train.

* Remove unused imports
This commit is contained in:
Boris Feld 2019-04-09 17:49:56 +02:00 коммит произвёл Marco
Родитель b651744b18
Коммит 6af6e8b927
3 изменённых файлов: 132 добавлений и 0 удалений

Просмотреть файл

@ -31,3 +31,27 @@ services:
target: /cache/
volume:
nocopy: true
# Training images: one service per model to train (component, defect,
# regression, tracking). Each one is built from the repository root using
# its own dockerfile under infra/.
bugbug-train-component:
build:
context: .
dockerfile: infra/dockerfile.train_component
image: mozilla/bugbug-train-component
bugbug-train-defect:
build:
context: .
dockerfile: infra/dockerfile.train_defect
image: mozilla/bugbug-train-defect
bugbug-train-regression:
build:
context: .
dockerfile: infra/dockerfile.train_regression
image: mozilla/bugbug-train-regression
bugbug-train-tracking:
build:
context: .
dockerfile: infra/dockerfile.train_tracking
image: mozilla/bugbug-train-tracking

107
scripts/trainer.py Normal file
Просмотреть файл

@ -0,0 +1,107 @@
# -*- coding: utf-8 -*-
import argparse
import lzma
import os
import shutil
from logging import INFO, basicConfig, getLogger
from urllib.request import urlretrieve
from bugbug.models.component import ComponentModel
from bugbug.models.defect_enhancement_task import DefectEnhancementTaskModel
from bugbug.models.regression import RegressionModel
from bugbug.models.tracking import TrackingModel
# Base URL under which the dataset archives (bugs.json.xz, commits.json.xz)
# are downloaded before training.
BASE_URL = "https://index.taskcluster.net/v1/task/project.releng.services.project.testing.bugbug_data.latest/artifacts/public"

# Emit INFO-level records so the progress messages logged below are visible.
basicConfig(level=INFO)

# Module-level logger used for all progress messages in this script.
logger = getLogger(__name__)
class Trainer:
    """Download the bugbug datasets and train one of the supported models.

    The public entry point is :meth:`go`. Each ``train_*`` method
    instantiates one model class, calls its ``train()`` method and then
    xz-compresses a file named after the model, relative to the current
    working directory.
    """

    def decompress_file(self, path):
        """Decompress ``{path}.xz`` into ``path``.

        Args:
            path: Destination file path. The compressed input is read from
                the same path with ``.xz`` appended.
        """
        with lzma.open(f"{path}.xz", "rb") as input_f, open(path, "wb") as output_f:
            shutil.copyfileobj(input_f, output_f)

    def compress_file(self, path):
        """Compress ``path`` into ``{path}.xz``; the input file is kept.

        Args:
            path: Source file path. The compressed output is written to the
                same path with ``.xz`` appended.
        """
        with open(path, "rb") as input_f, lzma.open(f"{path}.xz", "wb") as output_f:
            shutil.copyfileobj(input_f, output_f)

    def train_defect_enhancement_task(self):
        """Train the defect-vs-enhancement-vs-task model and compress it."""
        logger.info("Training *defect vs enhancement vs task* model")
        model = DefectEnhancementTaskModel()
        model.train()
        # NOTE(review): presumably ``train()`` writes this file to the
        # working directory -- confirm against the model implementation.
        self.compress_file("defectenhancementtaskmodel")

    def train_component(self):
        """Train the component model and compress it."""
        logger.info("Training *component* model")
        model = ComponentModel()
        model.train()
        self.compress_file("componentmodel")

    def train_regression(self):
        """Train the regression-vs-non-regression model and compress it."""
        logger.info("Training *regression vs non-regression* model")
        model = RegressionModel()
        model.train()
        self.compress_file("regressionmodel")

    def train_tracking(self):
        """Train the tracking model and compress it."""
        logger.info("Training *tracking* model")
        model = TrackingModel()
        model.train()
        self.compress_file("trackingmodel")

    def _trainers(self):
        """Map each supported model name to the method that trains it.

        This is the single source of truth for both validation and dispatch
        in :meth:`go`, so the two can no longer drift apart.
        """
        return {
            "defect": self.train_defect_enhancement_task,
            "component": self.train_component,
            "regression": self.train_regression,
            "tracking": self.train_tracking,
        }

    def _fetch_dataset(self, name):
        """Download ``{name}.json.xz`` into ``data/`` and decompress it."""
        logger.info("Downloading %s database", name)
        urlretrieve(f"{BASE_URL}/{name}.json.xz", f"data/{name}.json.xz")
        logger.info("Decompressing %s database", name)
        self.decompress_file(f"data/{name}.json")

    def go(self, model):
        """Download the datasets, then train the requested model.

        Both the bugs and the commits datasets are downloaded regardless of
        which model is requested.

        Args:
            model: Name of the model to train; one of ``"defect"``,
                ``"component"``, ``"regression"`` or ``"tracking"``.

        Raises:
            ValueError: If ``model`` is not a supported model name. This is
                checked before anything is downloaded.
        """
        trainers = self._trainers()
        if model not in trainers:
            valid_models = list(trainers)
            exception = (
                f"Invalid model {model!r} name, use one of {valid_models!r} instead"
            )
            raise ValueError(exception)

        # Download datasets that were built by bugbug_data.
        os.makedirs("data", exist_ok=True)
        self._fetch_dataset("bugs")
        self._fetch_dataset("commits")

        trainers[model]()
def main():
    """Parse the command line and train the model named by its argument."""
    parser = argparse.ArgumentParser(description="Train the models")
    parser.add_argument("model", help="Which model to train.")
    cli_args = parser.parse_args()

    # Validation of the model name is left to Trainer.go.
    Trainer().go(cli_args.model)

Просмотреть файл

@ -52,6 +52,7 @@ setup(
"console_scripts": [
"bugbug-data-commits = scripts.commit_retriever:main",
"bugbug-data-bugzilla = scripts.bug_retriever:main",
"bugbug-train = scripts.trainer:main",
]
},
)