Get rid of utils.download_and_load_model function

This commit is contained in:
Marco Castelluccio 2020-11-19 23:34:27 +01:00
Родитель 2f6e108a37
Коммит bfc1fa3a85
6 изменённых файлов: 24 добавлений и 24 удалений

Просмотреть файл

@@ -574,7 +574,7 @@ class Model:
 return tracking_metrics
 @staticmethod
-def load(model_file_name):
+def load(model_file_name: str) -> "Model":
 return joblib.load(model_file_name)
 def overwrite_classes(self, items, classes, probabilities):

Просмотреть файл

@@ -33,7 +33,6 @@ from sklearn.compose import ColumnTransformer
 from sklearn.preprocessing import OrdinalEncoder
 from bugbug import get_bugbug_version
-from bugbug.models import get_model_class
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
@@ -230,11 +229,6 @@ def download_model(model_name: str) -> str:
 return path
-def download_and_load_model(model_name):
-path = download_model(model_name)
-return get_model_class(model_name).load(path)
 def zstd_compress(path: str) -> None:
 cctx = zstandard.ZstdCompressor(threads=-1)
 with open(path, "rb") as input_f:

Просмотреть файл

@@ -4,7 +4,8 @@ import argparse
 import sys
 from logging import INFO, basicConfig, getLogger
-from bugbug.utils import download_and_load_model
+from bugbug.model import Model
+from bugbug.utils import download_model
 basicConfig(level=INFO)
 logger = getLogger(__name__)
@@ -13,7 +14,7 @@ logger = getLogger(__name__)
 class ModelChecker:
 def go(self, model_name: str) -> None:
 # Load the model
-model = download_and_load_model(model_name)
+model = Model.load(download_model(model_name))
 # Then call the check method of the model
 success = model.check()

Просмотреть файл

@@ -10,7 +10,7 @@ import re
 import subprocess
 from datetime import datetime
 from logging import INFO, basicConfig, getLogger
-from typing import Optional, Tuple
+from typing import Optional, Tuple, cast
 import dateutil.parser
 import hglib
@@ -26,9 +26,11 @@ from libmozdata.phabricator import PhabricatorAPI
 from scipy.stats import spearmanr
 from bugbug import db, repository, test_scheduling
+from bugbug.model import Model
+from bugbug.models.testfailure import TestFailureModel
 from bugbug.utils import (
-download_and_load_model,
 download_check_etag,
+download_model,
 get_secret,
 to_array,
 zstd_decompress,
@@ -137,7 +139,7 @@ class CommitClassifier(object):
 self.model_name = model_name
 self.repo_dir = repo_dir
-self.model = download_and_load_model(model_name)
+self.model = Model.load(download_model(model_name))
 assert self.model is not None
 self.git_repo_dir = git_repo_dir
@@ -196,7 +198,9 @@ class CommitClassifier(object):
 )
 self.past_failures_data = test_scheduling.get_past_failures("label", True)
-self.testfailure_model = download_and_load_model("testfailure")
+self.testfailure_model = cast(
+TestFailureModel, TestFailureModel.load(download_model("testfailure"))
+)
 assert self.testfailure_model is not None
 def clone_git_repo(self, repo_url, repo_dir, rev="origin/branches/default/tip"):

Просмотреть файл

@@ -10,17 +10,17 @@ import json
 import logging
 import os
 from datetime import datetime, timedelta
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, cast
 import dateutil.parser
 import requests
 from tqdm import tqdm
 from bugbug import bugzilla, db, phabricator, repository, test_scheduling
-from bugbug.models.regressor import BUG_FIXING_COMMITS_DB
+from bugbug.models.regressor import BUG_FIXING_COMMITS_DB, RegressorModel
 from bugbug.utils import (
-download_and_load_model,
 download_check_etag,
+download_model,
 get_secret,
 zstd_decompress,
 )
@@ -77,7 +77,9 @@ class LandingsRiskReportGenerator(object):
 logger.info("Download commit classifications...")
 assert db.download(BUG_FIXING_COMMITS_DB)
-self.regressor_model = download_and_load_model("regressor")
+self.regressor_model = cast(
+RegressorModel, RegressorModel.load(download_model("regressor"))
+)
 bugzilla.set_token(get_secret("BUGZILLA_TOKEN"))
 phabricator.set_api_key(

Просмотреть файл

@@ -29,11 +29,7 @@ from bugbug.models.regressor import (
 BUG_INTRODUCING_COMMITS_DB,
 TOKENIZED_BUG_INTRODUCING_COMMITS_DB,
 )
-from bugbug.utils import (
-ThreadPoolExecutorResult,
-download_and_load_model,
-zstd_compress,
-)
+from bugbug.utils import ThreadPoolExecutorResult, download_model, zstd_compress
 basicConfig(level=INFO)
 logger = getLogger(__name__)
@@ -197,11 +193,14 @@ class RegressorFinder(object):
 # TODO: Switch to the pure Defect model, as it's better in this case.
 logger.info("Downloading defect/enhancement/task model...")
 defect_model = cast(
-DefectEnhancementTaskModel, download_and_load_model("defectenhancementtask")
+DefectEnhancementTaskModel,
+DefectEnhancementTaskModel.load(download_model("defectenhancementtask")),
 )
 logger.info("Downloading regression model...")
-regression_model = cast(RegressionModel, download_and_load_model("regression"))
+regression_model = cast(
+RegressionModel, RegressionModel.load(download_model("regression"))
+)
 start_date = datetime.now() - RELATIVE_START_DATE
 end_date = datetime.now() - RELATIVE_END_DATE