зеркало из https://github.com/mozilla/bugbug.git
Use db.download in trainer.py instead of manually reimplementing download and decompression (#739)
Fixes #733
This commit is contained in:
Родитель
331aa50f1f
Коммит
51e6a712ef
|
@@ -4,44 +4,24 @@ import argparse
|
|||
import json
|
||||
import os
|
||||
from logging import INFO, basicConfig, getLogger
|
||||
from urllib.request import urlretrieve
|
||||
|
||||
import zstandard
|
||||
|
||||
from bugbug import get_bugbug_version, model
|
||||
from bugbug import bugzilla, db, model, repository
|
||||
from bugbug.models import get_model_class
|
||||
from bugbug.utils import CustomJsonEncoder
|
||||
|
||||
# Configure the root logger so that records at INFO level and above are emitted.
basicConfig(level=INFO)
|
||||
# Module-level logger, named after this module.
logger = getLogger(__name__)
|
||||
|
||||
# URL template for the public artifacts of a dataset task on the Taskcluster
# index. download_db fills the two placeholders with the database type and
# the version tag ("v" + the bugbug version), in that order.
BASE_URL = "https://index.taskcluster.net/v1/task/project.relman.bugbug.data_{}.{}/artifacts/public"
|
||||
|
||||
|
||||
class Trainer(object):
|
||||
def decompress_file(self, path):
    """Decompress the zstandard archive ``{path}.zst`` into ``path``.

    The data is streamed into a temporary sibling file (``{path}.tmp``)
    which is renamed over ``path`` only after the whole stream has been
    written. A failed or interrupted decompression therefore never leaves
    a truncated file at ``path``; any previous file there is kept.

    Args:
        path: Destination file path. The archive is read from
            ``f"{path}.zst"``.

    Raises:
        OSError: If the archive cannot be opened or the output cannot be
            written or renamed.
        Exception: Any error raised by the zstandard decompressor is
            propagated unchanged.
    """
    dctx = zstandard.ZstdDecompressor()
    tmp_path = f"{path}.tmp"
    try:
        with open(f"{path}.zst", "rb") as input_f, open(tmp_path, "wb") as output_f:
            dctx.copy_stream(input_f, output_f)
        # Publish the result only once it is complete. After a successful
        # rename `path` is guaranteed to exist, so no separate existence
        # assertion is needed (an `assert` would be stripped under -O anyway).
        os.replace(tmp_path, path)
    except BaseException:
        # Remove the partial output before propagating the failure.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
|
||||
|
||||
def compress_file(self, path):
    """Write a zstandard-compressed copy of ``path`` to ``{path}.zst``.

    The source file is read as binary and is not removed.

    Args:
        path: Path of the file to compress.
    """
    compressor = zstandard.ZstdCompressor()
    archive_path = f"{path}.zst"
    with open(path, "rb") as source, open(archive_path, "wb") as archive:
        compressor.copy_stream(source, archive)
|
||||
|
||||
def download_db(self, db_type):
    """Fetch the ``db_type`` database archive and decompress it locally.

    The archive URL is built from ``BASE_URL``, the database type and the
    current bugbug version. The archive is saved as
    ``data/{db_type}.json.zst`` and then decompressed to
    ``data/{db_type}.json`` via ``decompress_file``.

    Args:
        db_type: Name of the database, used in both the URL and the local
            file name.
    """
    version_tag = f"v{get_bugbug_version()}"
    artifacts_root = BASE_URL.format(db_type, version_tag)
    url = f"{artifacts_root}/{db_type}.json.zst"

    path = f"data/{db_type}.json"
    archive_path = f"{path}.zst"

    logger.info(f"Downloading {db_type} database from {url} to {path}.zst")
    urlretrieve(url, archive_path)
    assert os.path.exists(archive_path), "Downloaded file exists"

    logger.info(f"Decompressing {db_type} database")
    self.decompress_file(path)
|
||||
|
||||
def go(self, model_name):
|
||||
# Download datasets that were built by bugbug_data.
|
||||
os.makedirs("data", exist_ok=True)
|
||||
|
@@ -54,10 +34,10 @@ class Trainer(object):
|
|||
or isinstance(model_obj, model.BugCoupleModel)
|
||||
or (hasattr(model_obj, "bug_data") and model_obj.bug_data)
|
||||
):
|
||||
self.download_db("bugs")
|
||||
db.download(bugzilla.BUGS_DB, force=True)
|
||||
|
||||
if isinstance(model_obj, model.CommitModel):
|
||||
self.download_db("commits")
|
||||
db.download(repository.COMMITS_DB, force=True)
|
||||
|
||||
logger.info(f"Training *{model_name}* model")
|
||||
metrics = model_obj.train()
|
||||
|
|
Загрузка…
Ссылка в новой задаче