Add extra functionality from run.py to the trainer script (#856)

Fixes #339
This commit is contained in:
cklyyung 2019-08-05 19:19:59 -04:00 коммит произвёл Marco
Родитель fc31172b84
Коммит 81c17ecbaf
5 изменённых файлов: 77 добавлений и 109 удалений

Просмотреть файл

@ -54,7 +54,7 @@ Every time you will try to commit, pre-commit will run checks on your files to m
## Usage ## Usage
Run the `run.py` script to perform training / classification. The first time `run.py` is executed, the `--train` argument should be used to automatically download databases containing bugs and commits data (they will be downloaded in the data/ directory). Run the `trainer.py` script with the command `python3 -c 'from scripts import trainer; trainer.main()'` to perform training.
### Running the repository mining script ### Running the repository mining script

87
run.py
Просмотреть файл

@ -1,87 +0,0 @@
# -*- coding: utf-8 -*-
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this file,
# You can obtain one at http://mozilla.org/MPL/2.0/.
import argparse
import sys
from bugbug import bugzilla, db, repository
from bugbug.models import MODELS, get_model_class
def parse_args(args):
    """Parse the command-line arguments of the run script.

    Args:
        args: Sequence of argument strings without the program name
            (for example ``sys.argv[1:]``).

    Returns:
        The ``argparse.Namespace`` holding ``lemmatization``,
        ``training_set_size``, ``cleanup_urls``, ``train``, ``goal``,
        ``classifier`` and ``historical``.
    """
    default_training_set_size = 14000

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--lemmatization",
        help="Perform lemmatization (using spaCy)",
        action="store_true",
    )
    parser.add_argument(
        "--training-set-size",
        nargs="?",
        # With nargs="?" a bare "--training-set-size" (no value) produces
        # ``const``; without it the result would be None instead of an int.
        const=default_training_set_size,
        default=default_training_set_size,
        type=int,
        help="The size of the training set for the duplicate model",
    )
    parser.add_argument(
        "--disable-url-cleanup",
        help="Don't cleanup urls when training the duplicate model",
        dest="cleanup_urls",
        default=True,
        action="store_false",
    )
    parser.add_argument("--train", help="Perform training", action="store_true")
    parser.add_argument(
        "--goal", help="Goal of the classifier", choices=MODELS.keys(), default="defect"
    )
    parser.add_argument(
        "--classifier",
        help="Type of the classifier. Only used for component classification.",
        choices=["default", "nn"],
        default="default",
    )
    parser.add_argument(
        "--historical",
        help=(
            "Analyze historical bugs. Only used for defect, bugtype,\n"
            "defectenhancementtask and regression tasks."
        ),
        action="store_true",
    )
    return parser.parse_args(args)
def main(args):
    """Train the classifier selected by the parsed command-line arguments.

    Args:
        args: Namespace produced by ``parse_args``.
    """
    # The "component" goal has two implementations, chosen via --classifier;
    # every other goal maps directly to a model class name.
    if args.goal != "component":
        model_class_name = args.goal
    elif args.classifier == "default":
        model_class_name = "component"
    else:
        model_class_name = "component_nn"

    model_class = get_model_class(model_class_name)

    # NOTE(review): the indentation of this block was lost in the source;
    # everything below is assumed to have been nested under "if args.train:"
    # -- confirm against the original file.
    if not args.train:
        return

    db.download(bugzilla.BUGS_DB)
    db.download(repository.COMMITS_DB)

    # Constructor signatures differ per goal, so pick the positional
    # arguments first and instantiate once.
    if args.goal in ("defect", "bugtype", "defectenhancementtask", "regression"):
        ctor_args = (args.lemmatization, args.historical)
    elif args.goal == "duplicate":
        ctor_args = (args.training_set_size, args.lemmatization, args.cleanup_urls)
    else:
        ctor_args = (args.lemmatization,)

    model_class(*ctor_args).train()
# Script entry point: forward the command line (minus the program name).
if __name__ == "__main__":
    main(parse_args(args=sys.argv[1:]))

Просмотреть файл

@ -3,33 +3,58 @@
import argparse import argparse
import json import json
import os import os
import sys
from logging import INFO, basicConfig, getLogger from logging import INFO, basicConfig, getLogger
from bugbug import bugzilla, db, model, repository from bugbug import bugzilla, db, model, repository
from bugbug.models import get_model_class from bugbug.models import get_model_class
from bugbug.utils import CustomJsonEncoder, zstd_compress from bugbug.utils import CustomJsonEncoder, zstd_compress
# Models that accept a non-default --classifier type.
MODELS_WITH_TYPE = ("component",)

# Models whose constructor is additionally given the --historical flag.
HISTORICAL_SUPPORTED_TASKS = (
    "defect", "bugtype",
    "defectenhancementtask", "regression",
)
basicConfig(level=INFO) basicConfig(level=INFO)
logger = getLogger(__name__) logger = getLogger(__name__)
class Trainer(object): class Trainer(object):
def go(self, model_name): def go(self, args):
# Download datasets that were built by bugbug_data. # Download datasets that were built by bugbug_data.
os.makedirs("data", exist_ok=True) os.makedirs("data", exist_ok=True)
if args.classifier != "default":
assert (
args.model in MODELS_WITH_TYPE
), f"{args.classifier} is not a valid classifier type for {args.model}"
model_name = f"{args.model}_{args.classifier}"
else:
model_name = args.model
model_class = get_model_class(model_name) model_class = get_model_class(model_name)
model_obj = model_class() if args.model in HISTORICAL_SUPPORTED_TASKS:
model_obj = model_class(args.lemmatization, args.historical)
elif args.model == "duplicate":
model_obj = model_class(
args.training_set_size, args.lemmatization, args.cleanup_urls
)
else:
model_obj = model_class(args.lemmatization)
if ( if (
isinstance(model_obj, model.BugModel) isinstance(model_obj, model.BugModel)
or isinstance(model_obj, model.BugCoupleModel) or isinstance(model_obj, model.BugCoupleModel)
or (hasattr(model_obj, "bug_data") and model_obj.bug_data) or (hasattr(model_obj, "bug_data") and model_obj.bug_data)
): ):
db.download(bugzilla.BUGS_DB, force=True) db.download(bugzilla.BUGS_DB)
if isinstance(model_obj, model.CommitModel): if isinstance(model_obj, model.CommitModel):
db.download(repository.COMMITS_DB, force=True) db.download(repository.COMMITS_DB)
logger.info(f"Training *{model_name}* model") logger.info(f"Training *{model_name}* model")
metrics = model_obj.train() metrics = model_obj.train()
@ -48,13 +73,47 @@ class Trainer(object):
logger.info(f"Model compressed") logger.info(f"Model compressed")
def parse_args(args):
    """Parse the trainer's command-line arguments.

    Args:
        args: Sequence of argument strings without the program name
            (for example ``sys.argv[1:]``).

    Returns:
        The ``argparse.Namespace`` holding ``model``, ``lemmatization``,
        ``training_set_size``, ``cleanup_urls``, ``classifier`` and
        ``historical``.
    """
    default_training_set_size = 14000

    description = "Train the models"
    parser = argparse.ArgumentParser(description=description)

    parser.add_argument("model", help="Which model to train.")
    parser.add_argument(
        "--lemmatization",
        help="Perform lemmatization (using spaCy)",
        action="store_true",
    )
    parser.add_argument(
        "--training-set-size",
        nargs="?",
        # With nargs="?" a bare "--training-set-size" (no value) produces
        # ``const``; without it the result would be None instead of an int.
        const=default_training_set_size,
        default=default_training_set_size,
        type=int,
        help="The size of the training set for the duplicate model",
    )
    parser.add_argument(
        "--disable-url-cleanup",
        help="Don't cleanup urls when training the duplicate model",
        dest="cleanup_urls",
        default=True,
        action="store_false",
    )
    parser.add_argument(
        "--classifier",
        help="Type of the classifier. Only used for component classification.",
        choices=["default", "nn"],
        default="default",
    )
    parser.add_argument(
        "--historical",
        help=(
            "Analyze historical bugs. Only used for defect, bugtype,\n"
            "defectenhancementtask and regression tasks."
        ),
        action="store_true",
    )

    return parser.parse_args(args)
def main():
    """Parse the process command line and run the trainer with it."""
    cli_args = parse_args(sys.argv[1:])
    Trainer().go(cli_args)

Просмотреть файл

@ -1,14 +0,0 @@
# -*- coding: utf-8 -*-
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this file,
# You can obtain one at http://mozilla.org/MPL/2.0/.
import run
def test_run():
    """Smoke-test ``run.main``: first with --train, then without it."""
    invocations = (
        # Run the training for the bug model.
        ["--train", "--goal", "defect"],
        # Second pass without --train (loading the trained model).
        ["--goal", "defect"],
    )
    for argv in invocations:
        run.main(run.parse_args(argv))

10
tests/test_trainer.py Normal file
Просмотреть файл

@ -0,0 +1,10 @@
# -*- coding: utf-8 -*-
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this file,
# You can obtain one at http://mozilla.org/MPL/2.0/.
from scripts import trainer
def test_trainer():
    """Smoke-test ``Trainer.go`` for the defect model with default options."""
    runner = trainer.Trainer()
    cli_args = trainer.parse_args(["defect"])
    runner.go(cli_args)