bugbug/scripts/trainer.py

# -*- coding: utf-8 -*-
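"""Train the model selected on the command line and package the resulting artifacts."""
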
import argparse
import inspect
import json
import os
import sys
from logging import INFO, basicConfig, getLogger

from bugbug import db
from bugbug.models import MODELS, get_model_class
from bugbug.utils import CustomJsonEncoder, create_tar_zst, zstd_compress

basicConfig(level=INFO)
logger = getLogger(__name__)


class Trainer(object):
    def go(self, args):
        # Download datasets that were built by bugbug_data.
        os.makedirs("data", exist_ok=True)

        model_name = args.model

        model_class = get_model_class(model_name)

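        # Inspect the model's __init__ signature and forward only the CLI
        # arguments that match its parameter names, so each model class can
        # expose its own training options.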
        parameter_names = set(inspect.signature(model_class.__init__).parameters)
        parameters = {
            key: value for key, value in vars(args).items() if key in parameter_names
        }
        model_obj = model_class(**parameters)

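        # Fetch the prebuilt databases the model trains on, unless the user
        # opted out with --no-download; --download-eval additionally pulls the
        # support databases needed for custom evaluations.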
        if args.download_db:
            for required_db in model_obj.training_dbs:
                assert db.download(required_db)

            if args.download_eval:
                model_obj.download_eval_dbs()
        else:
            logger.info("Skipping download of the databases")

        logger.info("Training *%s* model", model_name)
        metrics = model_obj.train(limit=args.limit)

        # Save the metrics as a file that can be uploaded as an artifact.
        metric_file_path = "metrics.json"
        with open(metric_file_path, "w") as metric_file:
            json.dump(metrics, metric_file, cls=CustomJsonEncoder)

        logger.info("Training done")

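        # The trained model is expected to have been serialized into a
        # "<model_name>model" directory; compress it for upload as an artifact.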
        model_directory = f"{model_name}model"
        assert os.path.exists(model_directory)
        create_tar_zst(f"{model_directory}.tar.zst")

        logger.info("Model compressed")

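        # Models that set store_dataset also dump their feature matrix (X) and
        # labels (y) to disk; compress those files alongside the model.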
        if model_obj.store_dataset:
            assert os.path.exists(f"{model_name}model_data_X")
            zstd_compress(f"{model_name}model_data_X")
            assert os.path.exists(f"{model_name}model_data_y")
            zstd_compress(f"{model_name}model_data_y")


def parse_args(args):
    description = "Train the models"
    main_parser = argparse.ArgumentParser(description=description)

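    # Common options shared by every model subcommand; passed as a parent
    # parser below so each subparser inherits them.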
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument(
        "--limit",
        type=int,
        help="Only train on a subset of the data, used mainly for integration tests",
    )
    parser.add_argument(
        "--no-download",
        action="store_false",
        dest="download_db",
        help="Do not download databases, use whatever is on disk",
    )
    parser.add_argument(
        "--download-eval",
        action="store_true",
        dest="download_eval",
        help="Download databases and database support files required at runtime (e.g. if the model performs custom evaluations)",
    )
    parser.add_argument(
        "--lemmatization",
        help="Perform lemmatization (using spaCy)",
        action="store_true",
    )

    subparsers = main_parser.add_subparsers(title="model", dest="model", required=True)

    for model_name in MODELS:
        subparser = subparsers.add_parser(
            model_name, parents=[parser], help=f"Train {model_name} model"
        )

        try:
            model_class_init = get_model_class(model_name).__init__
        except ImportError:
            continue

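        # Turn every extra __init__ parameter of the model into a CLI option,
        # inferring its type from the annotation or, failing that, from the
        # default value.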
        for parameter in inspect.signature(model_class_init).parameters.values():
            if parameter.name == "self":
                continue

            # Skip parameters handled by the base class (TODO: add them to the
            # common argparser and skip them automatically without hardcoding
            # by inspecting the base class).
            if parameter.name == "lemmatization":
                continue

            parameter_type = parameter.annotation
            if parameter_type == inspect._empty:
                parameter_type = type(parameter.default)
            assert parameter_type is not None

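            # Boolean parameters become on/off flags: a False default yields a
            # --<name> flag that enables it, a True default yields a
            # --no-<name> flag that disables it.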
            if parameter_type == bool:
                subparser.add_argument(
                    f"--{parameter.name}"
                    if parameter.default is False
                    else f"--no-{parameter.name}",
                    action="store_true"
                    if parameter.default is False
                    else "store_false",
                    dest=parameter.name,
                )
            else:
                subparser.add_argument(
                    f"--{parameter.name}",
                    default=parameter.default,
                    dest=parameter.name,
                    type=parameter_type,
                )

    return main_parser.parse_args(args)


def main():
    args = parse_args(sys.argv[1:])

    retriever = Trainer()
    retriever.go(args)


if __name__ == "__main__":
    main()
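

# Example invocation (a minimal sketch, run from the repository root; "defect"
# stands in for any model name registered in bugbug.models.MODELS):
#   python scripts/trainer.py defect --limit 500 --no-download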