зеркало из https://github.com/mozilla/bugbug.git
Add extra functionality from run.py to the trainer script (#856)
Fixes #339
This commit is contained in:
Родитель
fc31172b84
Коммит
81c17ecbaf
|
@ -54,7 +54,7 @@ Every time you will try to commit, pre-commit will run checks on your files to m
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
|
|
||||||
Run the `run.py` script to perform training / classification. The first time `run.py` is executed, the `--train` argument should be used to automatically download databases containing bugs and commits data (they will be downloaded in the data/ directory).
|
Run the `trainer.py` script with the command `python3 -c 'from scripts import trainer; trainer.main()' <model>` to perform training, where `<model>` is the required name of the model to train (for example, `defect`).
|
||||||
|
|
||||||
### Running the repository mining script
|
### Running the repository mining script
|
||||||
|
|
||||||
|
|
87
run.py
87
run.py
|
@ -1,87 +0,0 @@
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
# This Source Code Form is subject to the terms of the Mozilla Public
|
|
||||||
# License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
|
||||||
# You can obtain one at http://mozilla.org/MPL/2.0/.
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import sys
|
|
||||||
|
|
||||||
from bugbug import bugzilla, db, repository
|
|
||||||
from bugbug.models import MODELS, get_model_class
|
|
||||||
|
|
||||||
|
|
||||||
def parse_args(args):
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument(
|
|
||||||
"--lemmatization",
|
|
||||||
help="Perform lemmatization (using spaCy)",
|
|
||||||
action="store_true",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--training-set-size",
|
|
||||||
nargs="?",
|
|
||||||
default=14000,
|
|
||||||
type=int,
|
|
||||||
help="The size of the training set for the duplicate model",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--disable-url-cleanup",
|
|
||||||
help="Don't cleanup urls when training the duplicate model",
|
|
||||||
dest="cleanup_urls",
|
|
||||||
default=True,
|
|
||||||
action="store_false",
|
|
||||||
)
|
|
||||||
parser.add_argument("--train", help="Perform training", action="store_true")
|
|
||||||
parser.add_argument(
|
|
||||||
"--goal", help="Goal of the classifier", choices=MODELS.keys(), default="defect"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--classifier",
|
|
||||||
help="Type of the classifier. Only used for component classification.",
|
|
||||||
choices=["default", "nn"],
|
|
||||||
default="default",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--historical",
|
|
||||||
help="""Analyze historical bugs. Only used for defect, bugtype,
|
|
||||||
defectenhancementtask and regression tasks.""",
|
|
||||||
action="store_true",
|
|
||||||
)
|
|
||||||
return parser.parse_args(args)
|
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
|
||||||
if args.goal == "component":
|
|
||||||
if args.classifier == "default":
|
|
||||||
model_class_name = "component"
|
|
||||||
else:
|
|
||||||
model_class_name = "component_nn"
|
|
||||||
else:
|
|
||||||
model_class_name = args.goal
|
|
||||||
|
|
||||||
model_class = get_model_class(model_class_name)
|
|
||||||
|
|
||||||
if args.train:
|
|
||||||
db.download(bugzilla.BUGS_DB)
|
|
||||||
db.download(repository.COMMITS_DB)
|
|
||||||
|
|
||||||
historical_supported_tasks = [
|
|
||||||
"defect",
|
|
||||||
"bugtype",
|
|
||||||
"defectenhancementtask",
|
|
||||||
"regression",
|
|
||||||
]
|
|
||||||
|
|
||||||
if args.goal in historical_supported_tasks:
|
|
||||||
model = model_class(args.lemmatization, args.historical)
|
|
||||||
elif args.goal == "duplicate":
|
|
||||||
model = model_class(
|
|
||||||
args.training_set_size, args.lemmatization, args.cleanup_urls
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
model = model_class(args.lemmatization)
|
|
||||||
model.train()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main(parse_args(sys.argv[1:]))
|
|
|
@ -3,33 +3,58 @@
|
||||||
import argparse
|
import argparse
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
import sys
|
||||||
from logging import INFO, basicConfig, getLogger
|
from logging import INFO, basicConfig, getLogger
|
||||||
|
|
||||||
from bugbug import bugzilla, db, model, repository
|
from bugbug import bugzilla, db, model, repository
|
||||||
from bugbug.models import get_model_class
|
from bugbug.models import get_model_class
|
||||||
from bugbug.utils import CustomJsonEncoder, zstd_compress
|
from bugbug.utils import CustomJsonEncoder, zstd_compress
|
||||||
|
|
||||||
|
MODELS_WITH_TYPE = ("component",)
|
||||||
|
HISTORICAL_SUPPORTED_TASKS = (
|
||||||
|
"defect",
|
||||||
|
"bugtype",
|
||||||
|
"defectenhancementtask",
|
||||||
|
"regression",
|
||||||
|
)
|
||||||
|
|
||||||
basicConfig(level=INFO)
|
basicConfig(level=INFO)
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class Trainer(object):
|
class Trainer(object):
|
||||||
def go(self, model_name):
|
def go(self, args):
|
||||||
# Download datasets that were built by bugbug_data.
|
# Download datasets that were built by bugbug_data.
|
||||||
os.makedirs("data", exist_ok=True)
|
os.makedirs("data", exist_ok=True)
|
||||||
|
|
||||||
|
if args.classifier != "default":
|
||||||
|
assert (
|
||||||
|
args.model in MODELS_WITH_TYPE
|
||||||
|
), f"{args.classifier} is not a valid classifier type for {args.model}"
|
||||||
|
|
||||||
|
model_name = f"{args.model}_{args.classifier}"
|
||||||
|
else:
|
||||||
|
model_name = args.model
|
||||||
|
|
||||||
model_class = get_model_class(model_name)
|
model_class = get_model_class(model_name)
|
||||||
model_obj = model_class()
|
if args.model in HISTORICAL_SUPPORTED_TASKS:
|
||||||
|
model_obj = model_class(args.lemmatization, args.historical)
|
||||||
|
elif args.model == "duplicate":
|
||||||
|
model_obj = model_class(
|
||||||
|
args.training_set_size, args.lemmatization, args.cleanup_urls
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
model_obj = model_class(args.lemmatization)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
isinstance(model_obj, model.BugModel)
|
isinstance(model_obj, model.BugModel)
|
||||||
or isinstance(model_obj, model.BugCoupleModel)
|
or isinstance(model_obj, model.BugCoupleModel)
|
||||||
or (hasattr(model_obj, "bug_data") and model_obj.bug_data)
|
or (hasattr(model_obj, "bug_data") and model_obj.bug_data)
|
||||||
):
|
):
|
||||||
db.download(bugzilla.BUGS_DB, force=True)
|
db.download(bugzilla.BUGS_DB)
|
||||||
|
|
||||||
if isinstance(model_obj, model.CommitModel):
|
if isinstance(model_obj, model.CommitModel):
|
||||||
db.download(repository.COMMITS_DB, force=True)
|
db.download(repository.COMMITS_DB)
|
||||||
|
|
||||||
logger.info(f"Training *{model_name}* model")
|
logger.info(f"Training *{model_name}* model")
|
||||||
metrics = model_obj.train()
|
metrics = model_obj.train()
|
||||||
|
@ -48,13 +73,47 @@ class Trainer(object):
|
||||||
logger.info(f"Model compressed")
|
logger.info(f"Model compressed")
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def parse_args(args):
|
||||||
description = "Train the models"
|
description = "Train the models"
|
||||||
parser = argparse.ArgumentParser(description=description)
|
parser = argparse.ArgumentParser(description=description)
|
||||||
|
|
||||||
parser.add_argument("model", help="Which model to train.")
|
parser.add_argument("model", help="Which model to train.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--lemmatization",
|
||||||
|
help="Perform lemmatization (using spaCy)",
|
||||||
|
action="store_true",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--training-set-size",
|
||||||
|
nargs="?",
|
||||||
|
default=14000,
|
||||||
|
type=int,
|
||||||
|
help="The size of the training set for the duplicate model",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--disable-url-cleanup",
|
||||||
|
help="Don't cleanup urls when training the duplicate model",
|
||||||
|
dest="cleanup_urls",
|
||||||
|
default=True,
|
||||||
|
action="store_false",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--classifier",
|
||||||
|
help="Type of the classifier. Only used for component classification.",
|
||||||
|
choices=["default", "nn"],
|
||||||
|
default="default",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--historical",
|
||||||
|
help="""Analyze historical bugs. Only used for defect, bugtype,
|
||||||
|
defectenhancementtask and regression tasks.""",
|
||||||
|
action="store_true",
|
||||||
|
)
|
||||||
|
return parser.parse_args(args)
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
def main():
|
||||||
|
args = parse_args(sys.argv[1:])
|
||||||
|
|
||||||
retriever = Trainer()
|
retriever = Trainer()
|
||||||
retriever.go(args.model)
|
retriever.go(args)
|
||||||
|
|
|
@ -1,14 +0,0 @@
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
# This Source Code Form is subject to the terms of the Mozilla Public
|
|
||||||
# License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
|
||||||
# You can obtain one at http://mozilla.org/MPL/2.0/.
|
|
||||||
|
|
||||||
import run
|
|
||||||
|
|
||||||
|
|
||||||
def test_run():
|
|
||||||
# Test running the training for the bug model.
|
|
||||||
run.main(run.parse_args(["--train", "--goal", "defect"]))
|
|
||||||
|
|
||||||
# Test loading the trained model.
|
|
||||||
run.main(run.parse_args(["--goal", "defect"]))
|
|
|
@ -0,0 +1,10 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# This Source Code Form is subject to the terms of the Mozilla Public
|
||||||
|
# License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||||
|
# You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||||
|
|
||||||
|
from scripts import trainer
|
||||||
|
|
||||||
|
|
||||||
|
def test_trainer():
|
||||||
|
trainer.Trainer().go(trainer.parse_args(["defect"]))
|
Загрузка…
Ссылка в новой задаче