From 51e6a712ef5c379e73ef30ff3c6c907b3faca42e Mon Sep 17 00:00:00 2001 From: Anurag Aggarwal Date: Mon, 22 Jul 2019 15:12:55 +0530 Subject: [PATCH] Use db.download in trainer.py instead of manually reimplementing download and decompression (#739) Fixes #733 --- scripts/trainer.py | 26 +++----------------------- 1 file changed, 3 insertions(+), 23 deletions(-) diff --git a/scripts/trainer.py b/scripts/trainer.py index 196f33d7..d79feb32 100644 --- a/scripts/trainer.py +++ b/scripts/trainer.py @@ -4,44 +4,24 @@ import argparse import json import os from logging import INFO, basicConfig, getLogger -from urllib.request import urlretrieve import zstandard -from bugbug import get_bugbug_version, model +from bugbug import bugzilla, db, model, repository from bugbug.models import get_model_class from bugbug.utils import CustomJsonEncoder basicConfig(level=INFO) logger = getLogger(__name__) -BASE_URL = "https://index.taskcluster.net/v1/task/project.relman.bugbug.data_{}.{}/artifacts/public" - class Trainer(object): - def decompress_file(self, path): - dctx = zstandard.ZstdDecompressor() - with open(f"{path}.zst", "rb") as input_f: - with open(path, "wb") as output_f: - dctx.copy_stream(input_f, output_f) - assert os.path.exists(path), "Decompressed file exists" - def compress_file(self, path): cctx = zstandard.ZstdCompressor() with open(path, "rb") as input_f: with open(f"{path}.zst", "wb") as output_f: cctx.copy_stream(input_f, output_f) - def download_db(self, db_type): - path = f"data/{db_type}.json" - formatted_base_url = BASE_URL.format(db_type, f"v{get_bugbug_version()}") - url = f"{formatted_base_url}/{db_type}.json.zst" - logger.info(f"Downloading {db_type} database from {url} to {path}.zst") - urlretrieve(url, f"{path}.zst") - assert os.path.exists(f"{path}.zst"), "Downloaded file exists" - logger.info(f"Decompressing {db_type} database") - self.decompress_file(path) - def go(self, model_name): # Download datasets that were built by bugbug_data. os.makedirs("data", exist_ok=True) @@ -54,10 +34,10 @@ class Trainer(object): or isinstance(model_obj, model.BugCoupleModel) or (hasattr(model_obj, "bug_data") and model_obj.bug_data) ): - self.download_db("bugs") + db.download(bugzilla.BUGS_DB, force=True) if isinstance(model_obj, model.CommitModel): - self.download_db("commits") + db.download(repository.COMMITS_DB, force=True) logger.info(f"Training *{model_name}* model") metrics = model_obj.train()