From 81c17ecbafa87d25ff667e7a79c1f6a2dd689c21 Mon Sep 17 00:00:00 2001 From: cklyyung Date: Mon, 5 Aug 2019 19:19:59 -0400 Subject: [PATCH] Add extra functionality from run.py to the trainer script (#856) Fixes #339 --- README.md | 2 +- run.py | 87 ------------------------------------------- scripts/trainer.py | 73 ++++++++++++++++++++++++++++++++---- tests/test_run.py | 14 ------- tests/test_trainer.py | 10 +++++ 5 files changed, 77 insertions(+), 109 deletions(-) delete mode 100644 run.py delete mode 100644 tests/test_run.py create mode 100644 tests/test_trainer.py diff --git a/README.md b/README.md index f0fd345f..6ce1a297 100644 --- a/README.md +++ b/README.md @@ -54,7 +54,7 @@ Every time you will try to commit, pre-commit will run checks on your files to m ## Usage -Run the `run.py` script to perform training / classification. The first time `run.py` is executed, the `--train` argument should be used to automatically download databases containing bugs and commits data (they will be downloaded in the data/ directory). +Run the `trainer.py` script with the command `python3 -c 'from scripts import trainer; trainer.main()'` to perform training. ### Running the repository mining script diff --git a/run.py b/run.py deleted file mode 100644 index 6f92ff6f..00000000 --- a/run.py +++ /dev/null @@ -1,87 +0,0 @@ -# -*- coding: utf-8 -*- -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this file, -# You can obtain one at http://mozilla.org/MPL/2.0/. - -import argparse -import sys - -from bugbug import bugzilla, db, repository -from bugbug.models import MODELS, get_model_class - - -def parse_args(args): - parser = argparse.ArgumentParser() - parser.add_argument( - "--lemmatization", - help="Perform lemmatization (using spaCy)", - action="store_true", - ) - parser.add_argument( - "--training-set-size", - nargs="?", - default=14000, - type=int, - help="The size of the training set for the duplicate model", - ) - parser.add_argument( - "--disable-url-cleanup", - help="Don't cleanup urls when training the duplicate model", - dest="cleanup_urls", - default=True, - action="store_false", - ) - parser.add_argument("--train", help="Perform training", action="store_true") - parser.add_argument( - "--goal", help="Goal of the classifier", choices=MODELS.keys(), default="defect" - ) - parser.add_argument( - "--classifier", - help="Type of the classifier. Only used for component classification.", - choices=["default", "nn"], - default="default", - ) - parser.add_argument( - "--historical", - help="""Analyze historical bugs. Only used for defect, bugtype, - defectenhancementtask and regression tasks.""", - action="store_true", - ) - return parser.parse_args(args) - - -def main(args): - if args.goal == "component": - if args.classifier == "default": - model_class_name = "component" - else: - model_class_name = "component_nn" - else: - model_class_name = args.goal - - model_class = get_model_class(model_class_name) - - if args.train: - db.download(bugzilla.BUGS_DB) - db.download(repository.COMMITS_DB) - - historical_supported_tasks = [ - "defect", - "bugtype", - "defectenhancementtask", - "regression", - ] - - if args.goal in historical_supported_tasks: - model = model_class(args.lemmatization, args.historical) - elif args.goal == "duplicate": - model = model_class( - args.training_set_size, args.lemmatization, args.cleanup_urls - ) - else: - model = model_class(args.lemmatization) - model.train() - - -if __name__ == "__main__": - main(parse_args(sys.argv[1:])) diff --git a/scripts/trainer.py b/scripts/trainer.py index 66ba6e5d..03da9bd8 100644 --- a/scripts/trainer.py +++ b/scripts/trainer.py @@ -3,33 +3,58 @@ import argparse import json import os +import sys from logging import INFO, basicConfig, getLogger from bugbug import bugzilla, db, model, repository from bugbug.models import get_model_class from bugbug.utils import CustomJsonEncoder, zstd_compress +MODELS_WITH_TYPE = ("component",) +HISTORICAL_SUPPORTED_TASKS = ( + "defect", + "bugtype", + "defectenhancementtask", + "regression", +) + basicConfig(level=INFO) logger = getLogger(__name__) class Trainer(object): - def go(self, model_name): + def go(self, args): # Download datasets that were built by bugbug_data. os.makedirs("data", exist_ok=True) + if args.classifier != "default": + assert ( + args.model in MODELS_WITH_TYPE + ), f"{args.classifier} is not a valid classifier type for {args.model}" + + model_name = f"{args.model}_{args.classifier}" + else: + model_name = args.model + model_class = get_model_class(model_name) - model_obj = model_class() + if args.model in HISTORICAL_SUPPORTED_TASKS: + model_obj = model_class(args.lemmatization, args.historical) + elif args.model == "duplicate": + model_obj = model_class( + args.training_set_size, args.lemmatization, args.cleanup_urls + ) + else: + model_obj = model_class(args.lemmatization) if ( isinstance(model_obj, model.BugModel) or isinstance(model_obj, model.BugCoupleModel) or (hasattr(model_obj, "bug_data") and model_obj.bug_data) ): - db.download(bugzilla.BUGS_DB, force=True) + db.download(bugzilla.BUGS_DB) if isinstance(model_obj, model.CommitModel): - db.download(repository.COMMITS_DB, force=True) + db.download(repository.COMMITS_DB) logger.info(f"Training *{model_name}* model") metrics = model_obj.train() @@ -48,13 +73,47 @@ class Trainer(object): logger.info(f"Model compressed") -def main(): +def parse_args(args): description = "Train the models" parser = argparse.ArgumentParser(description=description) parser.add_argument("model", help="Which model to train.") + parser.add_argument( + "--lemmatization", + help="Perform lemmatization (using spaCy)", + action="store_true", + ) + parser.add_argument( + "--training-set-size", + nargs="?", + default=14000, + type=int, + help="The size of the training set for the duplicate model", + ) + parser.add_argument( + "--disable-url-cleanup", + help="Don't cleanup urls when training the duplicate model", + dest="cleanup_urls", + default=True, + action="store_false", + ) + parser.add_argument( + "--classifier", + help="Type of the classifier. Only used for component classification.", + choices=["default", "nn"], + default="default", + ) + parser.add_argument( + "--historical", + help="""Analyze historical bugs. Only used for defect, bugtype, + defectenhancementtask and regression tasks.""", + action="store_true", + ) + return parser.parse_args(args) - args = parser.parse_args() + +def main(): + args = parse_args(sys.argv[1:]) retriever = Trainer() - retriever.go(args.model) + retriever.go(args) diff --git a/tests/test_run.py b/tests/test_run.py deleted file mode 100644 index 3d71a9d1..00000000 --- a/tests/test_run.py +++ /dev/null @@ -1,14 +0,0 @@ -# -*- coding: utf-8 -*- -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this file, -# You can obtain one at http://mozilla.org/MPL/2.0/. - -import run - - -def test_run(): - # Test running the training for the bug model. - run.main(run.parse_args(["--train", "--goal", "defect"])) - - # Test loading the trained model. - run.main(run.parse_args(["--goal", "defect"])) diff --git a/tests/test_trainer.py b/tests/test_trainer.py new file mode 100644 index 00000000..91149852 --- /dev/null +++ b/tests/test_trainer.py @@ -0,0 +1,10 @@ +# -*- coding: utf-8 -*- +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this file, +# You can obtain one at http://mozilla.org/MPL/2.0/. + +from scripts import trainer + + +def test_trainer(): + trainer.Trainer().go(trainer.parse_args(["defect"]))