Use db.download in trainer.py instead of manually reimplementing download and decompression (#739)

Fixes #733
This commit is contained in:
Anurag Aggarwal 2019-07-22 15:12:55 +05:30 коммит произвёл Marco
Родитель 331aa50f1f
Коммит 51e6a712ef
1 изменённых файлов: 3 добавлений и 23 удалений

Просмотреть файл

@ -4,44 +4,24 @@ import argparse
import json
import os
from logging import INFO, basicConfig, getLogger
from urllib.request import urlretrieve
import zstandard
from bugbug import get_bugbug_version, model
from bugbug import bugzilla, db, model, repository
from bugbug.models import get_model_class
from bugbug.utils import CustomJsonEncoder
basicConfig(level=INFO)
logger = getLogger(__name__)
BASE_URL = "https://index.taskcluster.net/v1/task/project.relman.bugbug.data_{}.{}/artifacts/public"
class Trainer(object):
def decompress_file(self, path):
dctx = zstandard.ZstdDecompressor()
with open(f"{path}.zst", "rb") as input_f:
with open(path, "wb") as output_f:
dctx.copy_stream(input_f, output_f)
assert os.path.exists(path), "Decompressed file exists"
def compress_file(self, path):
cctx = zstandard.ZstdCompressor()
with open(path, "rb") as input_f:
with open(f"{path}.zst", "wb") as output_f:
cctx.copy_stream(input_f, output_f)
def download_db(self, db_type):
path = f"data/{db_type}.json"
formatted_base_url = BASE_URL.format(db_type, f"v{get_bugbug_version()}")
url = f"{formatted_base_url}/{db_type}.json.zst"
logger.info(f"Downloading {db_type} database from {url} to {path}.zst")
urlretrieve(url, f"{path}.zst")
assert os.path.exists(f"{path}.zst"), "Downloaded file exists"
logger.info(f"Decompressing {db_type} database")
self.decompress_file(path)
def go(self, model_name):
# Download datasets that were built by bugbug_data.
os.makedirs("data", exist_ok=True)
@ -54,10 +34,10 @@ class Trainer(object):
or isinstance(model_obj, model.BugCoupleModel)
or (hasattr(model_obj, "bug_data") and model_obj.bug_data)
):
self.download_db("bugs")
db.download(bugzilla.BUGS_DB, force=True)
if isinstance(model_obj, model.CommitModel):
self.download_db("commits")
db.download(repository.COMMITS_DB, force=True)
logger.info(f"Training *{model_name}* model")
metrics = model_obj.train()