зеркало из https://github.com/mozilla/bugbug.git
Add extra functionality from run.py to the trainer script (#856)
Fixes #339
This commit is contained in:
Родитель
fc31172b84
Коммит
81c17ecbaf
|
@ -54,7 +54,7 @@ Every time you will try to commit, pre-commit will run checks on your files to m
|
|||
|
||||
## Usage
|
||||
|
||||
Run the `run.py` script to perform training / classification. The first time `run.py` is executed, the `--train` argument should be used to automatically download the databases containing bug and commit data (they will be downloaded into the `data/` directory).
|
||||
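Run the `trainer.py` script with the command `python3 -c 'from scripts import trainer; trainer.main()'` to perform training. The script requires the name of the model to train as a positional argument, so append it after the command (for example `defect`).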
|
||||
|
||||
### Running the repository mining script
|
||||
|
||||
|
|
87
run.py
87
run.py
|
@ -1,87 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# This Source Code Form is subject to the terms of the Mozilla Public
|
||||
# License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
# You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
from bugbug import bugzilla, db, repository
|
||||
from bugbug.models import MODELS, get_model_class
|
||||
|
||||
|
||||
def parse_args(args):
    """Parse the command-line arguments of the ``run.py`` script.

    Args:
        args: List of argument strings, typically ``sys.argv[1:]``.

    Returns:
        The ``argparse.Namespace`` holding ``lemmatization``,
        ``training_set_size``, ``cleanup_urls``, ``train``, ``goal``,
        ``classifier`` and ``historical``.
    """
    default_training_set_size = 14000

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--lemmatization",
        help="Perform lemmatization (using spaCy)",
        action="store_true",
    )
    parser.add_argument(
        "--training-set-size",
        nargs="?",
        # With nargs="?" a bare ``--training-set-size`` stores ``const``, which
        # defaults to None; fall back to the default size instead so the value
        # is always an int.
        const=default_training_set_size,
        default=default_training_set_size,
        type=int,
        help="The size of the training set for the duplicate model",
    )
    parser.add_argument(
        "--disable-url-cleanup",
        help="Don't cleanup urls when training the duplicate model",
        dest="cleanup_urls",
        default=True,
        action="store_false",
    )
    parser.add_argument("--train", help="Perform training", action="store_true")
    parser.add_argument(
        "--goal", help="Goal of the classifier", choices=MODELS.keys(), default="defect"
    )
    parser.add_argument(
        "--classifier",
        help="Type of the classifier. Only used for component classification.",
        choices=["default", "nn"],
        default="default",
    )
    parser.add_argument(
        "--historical",
        help="""Analyze historical bugs. Only used for defect, bugtype,
                defectenhancementtask and regression tasks.""",
        action="store_true",
    )
    return parser.parse_args(args)
|
||||
|
||||
|
||||
def main(args):
    """Resolve the model class for ``args.goal`` and, if requested, train it.

    Args:
        args: Parsed namespace as produced by ``parse_args``.

    The model class is always looked up; the databases are downloaded and the
    model is instantiated and trained only when ``args.train`` is set.
    """
    # Only the component goal has an alternative ("nn") implementation.
    wants_nn_component = args.goal == "component" and args.classifier != "default"
    model_class = get_model_class("component_nn" if wants_nn_component else args.goal)

    if not args.train:
        return

    for database in (bugzilla.BUGS_DB, repository.COMMITS_DB):
        db.download(database)

    # The constructor signature differs between model families.
    if args.goal in ("defect", "bugtype", "defectenhancementtask", "regression"):
        ctor_args = (args.lemmatization, args.historical)
    elif args.goal == "duplicate":
        ctor_args = (args.training_set_size, args.lemmatization, args.cleanup_urls)
    else:
        ctor_args = (args.lemmatization,)

    model_class(*ctor_args).train()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main(parse_args(sys.argv[1:]))
|
|
@ -3,33 +3,58 @@
|
|||
import argparse
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from logging import INFO, basicConfig, getLogger
|
||||
|
||||
from bugbug import bugzilla, db, model, repository
|
||||
from bugbug.models import get_model_class
|
||||
from bugbug.utils import CustomJsonEncoder, zstd_compress
|
||||
|
||||
MODELS_WITH_TYPE = ("component",)
|
||||
HISTORICAL_SUPPORTED_TASKS = (
|
||||
"defect",
|
||||
"bugtype",
|
||||
"defectenhancementtask",
|
||||
"regression",
|
||||
)
|
||||
|
||||
basicConfig(level=INFO)
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
class Trainer(object):
|
||||
def go(self, model_name):
|
||||
def go(self, args):
|
||||
# Download datasets that were built by bugbug_data.
|
||||
os.makedirs("data", exist_ok=True)
|
||||
|
||||
if args.classifier != "default":
|
||||
assert (
|
||||
args.model in MODELS_WITH_TYPE
|
||||
), f"{args.classifier} is not a valid classifier type for {args.model}"
|
||||
|
||||
model_name = f"{args.model}_{args.classifier}"
|
||||
else:
|
||||
model_name = args.model
|
||||
|
||||
model_class = get_model_class(model_name)
|
||||
model_obj = model_class()
|
||||
if args.model in HISTORICAL_SUPPORTED_TASKS:
|
||||
model_obj = model_class(args.lemmatization, args.historical)
|
||||
elif args.model == "duplicate":
|
||||
model_obj = model_class(
|
||||
args.training_set_size, args.lemmatization, args.cleanup_urls
|
||||
)
|
||||
else:
|
||||
model_obj = model_class(args.lemmatization)
|
||||
|
||||
if (
|
||||
isinstance(model_obj, model.BugModel)
|
||||
or isinstance(model_obj, model.BugCoupleModel)
|
||||
or (hasattr(model_obj, "bug_data") and model_obj.bug_data)
|
||||
):
|
||||
db.download(bugzilla.BUGS_DB, force=True)
|
||||
db.download(bugzilla.BUGS_DB)
|
||||
|
||||
if isinstance(model_obj, model.CommitModel):
|
||||
db.download(repository.COMMITS_DB, force=True)
|
||||
db.download(repository.COMMITS_DB)
|
||||
|
||||
logger.info(f"Training *{model_name}* model")
|
||||
metrics = model_obj.train()
|
||||
|
@ -48,13 +73,47 @@ class Trainer(object):
|
|||
logger.info(f"Model compressed")
|
||||
|
||||
|
||||
def main():
|
||||
def parse_args(args):
    """Parse the command-line arguments of the trainer script.

    Args:
        args: List of argument strings, typically ``sys.argv[1:]``.

    Returns:
        The ``argparse.Namespace`` holding ``model``, ``lemmatization``,
        ``training_set_size``, ``cleanup_urls``, ``classifier`` and
        ``historical``.
    """
    default_training_set_size = 14000

    description = "Train the models"
    parser = argparse.ArgumentParser(description=description)

    parser.add_argument("model", help="Which model to train.")
    parser.add_argument(
        "--lemmatization",
        help="Perform lemmatization (using spaCy)",
        action="store_true",
    )
    parser.add_argument(
        "--training-set-size",
        nargs="?",
        # With nargs="?" a bare ``--training-set-size`` stores ``const``, which
        # defaults to None; fall back to the default size instead so the value
        # is always an int.
        const=default_training_set_size,
        default=default_training_set_size,
        type=int,
        help="The size of the training set for the duplicate model",
    )
    parser.add_argument(
        "--disable-url-cleanup",
        help="Don't cleanup urls when training the duplicate model",
        dest="cleanup_urls",
        default=True,
        action="store_false",
    )
    parser.add_argument(
        "--classifier",
        help="Type of the classifier. Only used for component classification.",
        choices=["default", "nn"],
        default="default",
    )
    parser.add_argument(
        "--historical",
        help="""Analyze historical bugs. Only used for defect, bugtype,
                defectenhancementtask and regression tasks.""",
        action="store_true",
    )
    return parser.parse_args(args)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
def main():
    """Command-line entry point: parse ``sys.argv`` and run the trainer."""
    args = parse_args(sys.argv[1:])

    # ``Trainer.go`` reads several attributes of the parsed namespace
    # (classifier, model, ...), so it must receive ``args`` itself rather than
    # just the model name, and it must be invoked only once.
    trainer = Trainer()
    trainer.go(args)
|
||||
|
|
|
@ -1,14 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# This Source Code Form is subject to the terms of the Mozilla Public
|
||||
# License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
# You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
import run
|
||||
|
||||
|
||||
def test_run():
    """Smoke-test ``run.main`` for the defect goal, with and without training."""
    invocations = (
        # Test running the training for the bug model.
        ["--train", "--goal", "defect"],
        # Test loading the trained model.
        ["--goal", "defect"],
    )
    for cli_args in invocations:
        run.main(run.parse_args(cli_args))
|
|
@ -0,0 +1,10 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# This Source Code Form is subject to the terms of the Mozilla Public
|
||||
# License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
# You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
from scripts import trainer
|
||||
|
||||
|
||||
def test_trainer():
    """Smoke-test: run the trainer end to end for the ``defect`` model."""
    runner = trainer.Trainer()
    runner.go(trainer.parse_args(["defect"]))
|
Загрузка…
Ссылка в новой задаче