Add extra functionality from run.py to the trainer script (#856)

Fixes #339
This commit is contained in:
cklyyung 2019-08-05 19:19:59 -04:00 коммит произвёл Marco
Родитель fc31172b84
Коммит 81c17ecbaf
5 изменённых файлов: 77 добавлений и 109 удалений

Просмотреть файл

@ -54,7 +54,7 @@ Every time you will try to commit, pre-commit will run checks on your files to m
## Usage
Run the `run.py` script to perform training / classification. The first time `run.py` is executed, the `--train` argument should be used to automatically download databases containing bugs and commits data (they will be downloaded in the data/ directory).
Run the `trainer.py` script with the command `python3 -c 'from scripts import trainer; trainer.main()'` to perform training.
### Running the repository mining script

87
run.py
Просмотреть файл

@ -1,87 +0,0 @@
# -*- coding: utf-8 -*-
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this file,
# You can obtain one at http://mozilla.org/MPL/2.0/.
import argparse
import sys
from bugbug import bugzilla, db, repository
from bugbug.models import MODELS, get_model_class
def parse_args(args):
parser = argparse.ArgumentParser()
parser.add_argument(
"--lemmatization",
help="Perform lemmatization (using spaCy)",
action="store_true",
)
parser.add_argument(
"--training-set-size",
nargs="?",
default=14000,
type=int,
help="The size of the training set for the duplicate model",
)
parser.add_argument(
"--disable-url-cleanup",
help="Don't cleanup urls when training the duplicate model",
dest="cleanup_urls",
default=True,
action="store_false",
)
parser.add_argument("--train", help="Perform training", action="store_true")
parser.add_argument(
"--goal", help="Goal of the classifier", choices=MODELS.keys(), default="defect"
)
parser.add_argument(
"--classifier",
help="Type of the classifier. Only used for component classification.",
choices=["default", "nn"],
default="default",
)
parser.add_argument(
"--historical",
help="""Analyze historical bugs. Only used for defect, bugtype,
defectenhancementtask and regression tasks.""",
action="store_true",
)
return parser.parse_args(args)
def main(args):
if args.goal == "component":
if args.classifier == "default":
model_class_name = "component"
else:
model_class_name = "component_nn"
else:
model_class_name = args.goal
model_class = get_model_class(model_class_name)
if args.train:
db.download(bugzilla.BUGS_DB)
db.download(repository.COMMITS_DB)
historical_supported_tasks = [
"defect",
"bugtype",
"defectenhancementtask",
"regression",
]
if args.goal in historical_supported_tasks:
model = model_class(args.lemmatization, args.historical)
elif args.goal == "duplicate":
model = model_class(
args.training_set_size, args.lemmatization, args.cleanup_urls
)
else:
model = model_class(args.lemmatization)
model.train()
if __name__ == "__main__":
main(parse_args(sys.argv[1:]))

Просмотреть файл

@ -3,33 +3,58 @@
import argparse
import json
import os
import sys
from logging import INFO, basicConfig, getLogger
from bugbug import bugzilla, db, model, repository
from bugbug.models import get_model_class
from bugbug.utils import CustomJsonEncoder, zstd_compress
MODELS_WITH_TYPE = ("component",)
HISTORICAL_SUPPORTED_TASKS = (
"defect",
"bugtype",
"defectenhancementtask",
"regression",
)
basicConfig(level=INFO)
logger = getLogger(__name__)
class Trainer(object):
def go(self, model_name):
def go(self, args):
# Download datasets that were built by bugbug_data.
os.makedirs("data", exist_ok=True)
if args.classifier != "default":
assert (
args.model in MODELS_WITH_TYPE
), f"{args.classifier} is not a valid classifier type for {args.model}"
model_name = f"{args.model}_{args.classifier}"
else:
model_name = args.model
model_class = get_model_class(model_name)
model_obj = model_class()
if args.model in HISTORICAL_SUPPORTED_TASKS:
model_obj = model_class(args.lemmatization, args.historical)
elif args.model == "duplicate":
model_obj = model_class(
args.training_set_size, args.lemmatization, args.cleanup_urls
)
else:
model_obj = model_class(args.lemmatization)
if (
isinstance(model_obj, model.BugModel)
or isinstance(model_obj, model.BugCoupleModel)
or (hasattr(model_obj, "bug_data") and model_obj.bug_data)
):
db.download(bugzilla.BUGS_DB, force=True)
db.download(bugzilla.BUGS_DB)
if isinstance(model_obj, model.CommitModel):
db.download(repository.COMMITS_DB, force=True)
db.download(repository.COMMITS_DB)
logger.info(f"Training *{model_name}* model")
metrics = model_obj.train()
@ -48,13 +73,47 @@ class Trainer(object):
logger.info(f"Model compressed")
def main():
def parse_args(args):
description = "Train the models"
parser = argparse.ArgumentParser(description=description)
parser.add_argument("model", help="Which model to train.")
parser.add_argument(
"--lemmatization",
help="Perform lemmatization (using spaCy)",
action="store_true",
)
parser.add_argument(
"--training-set-size",
nargs="?",
default=14000,
type=int,
help="The size of the training set for the duplicate model",
)
parser.add_argument(
"--disable-url-cleanup",
help="Don't cleanup urls when training the duplicate model",
dest="cleanup_urls",
default=True,
action="store_false",
)
parser.add_argument(
"--classifier",
help="Type of the classifier. Only used for component classification.",
choices=["default", "nn"],
default="default",
)
parser.add_argument(
"--historical",
help="""Analyze historical bugs. Only used for defect, bugtype,
defectenhancementtask and regression tasks.""",
action="store_true",
)
return parser.parse_args(args)
args = parser.parse_args()
def main():
args = parse_args(sys.argv[1:])
retriever = Trainer()
retriever.go(args.model)
retriever.go(args)

Просмотреть файл

@ -1,14 +0,0 @@
# -*- coding: utf-8 -*-
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this file,
# You can obtain one at http://mozilla.org/MPL/2.0/.
import run
def test_run():
# Test running the training for the bug model.
run.main(run.parse_args(["--train", "--goal", "defect"]))
# Test loading the trained model.
run.main(run.parse_args(["--goal", "defect"]))

10
tests/test_trainer.py Normal file
Просмотреть файл

@ -0,0 +1,10 @@
# -*- coding: utf-8 -*-
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this file,
# You can obtain one at http://mozilla.org/MPL/2.0/.
from scripts import trainer
def test_trainer():
trainer.Trainer().go(trainer.parse_args(["defect"]))