зеркало из https://github.com/mozilla/bugbug.git
Get rid of utils.download_and_load_model function
This commit is contained in:
Родитель
2f6e108a37
Коммит
bfc1fa3a85
|
@ -574,7 +574,7 @@ class Model:
|
|||
return tracking_metrics
|
||||
|
||||
@staticmethod
|
||||
def load(model_file_name):
|
||||
def load(model_file_name: str) -> "Model":
|
||||
return joblib.load(model_file_name)
|
||||
|
||||
def overwrite_classes(self, items, classes, probabilities):
|
||||
|
|
|
@ -33,7 +33,6 @@ from sklearn.compose import ColumnTransformer
|
|||
from sklearn.preprocessing import OrdinalEncoder
|
||||
|
||||
from bugbug import get_bugbug_version
|
||||
from bugbug.models import get_model_class
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -230,11 +229,6 @@ def download_model(model_name: str) -> str:
|
|||
return path
|
||||
|
||||
|
||||
def download_and_load_model(model_name):
|
||||
path = download_model(model_name)
|
||||
return get_model_class(model_name).load(path)
|
||||
|
||||
|
||||
def zstd_compress(path: str) -> None:
|
||||
cctx = zstandard.ZstdCompressor(threads=-1)
|
||||
with open(path, "rb") as input_f:
|
||||
|
|
|
@ -4,7 +4,8 @@ import argparse
|
|||
import sys
|
||||
from logging import INFO, basicConfig, getLogger
|
||||
|
||||
from bugbug.utils import download_and_load_model
|
||||
from bugbug.model import Model
|
||||
from bugbug.utils import download_model
|
||||
|
||||
basicConfig(level=INFO)
|
||||
logger = getLogger(__name__)
|
||||
|
@ -13,7 +14,7 @@ logger = getLogger(__name__)
|
|||
class ModelChecker:
|
||||
def go(self, model_name: str) -> None:
|
||||
# Load the model
|
||||
model = download_and_load_model(model_name)
|
||||
model = Model.load(download_model(model_name))
|
||||
|
||||
# Then call the check method of the model
|
||||
success = model.check()
|
||||
|
|
|
@ -10,7 +10,7 @@ import re
|
|||
import subprocess
|
||||
from datetime import datetime
|
||||
from logging import INFO, basicConfig, getLogger
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional, Tuple, cast
|
||||
|
||||
import dateutil.parser
|
||||
import hglib
|
||||
|
@ -26,9 +26,11 @@ from libmozdata.phabricator import PhabricatorAPI
|
|||
from scipy.stats import spearmanr
|
||||
|
||||
from bugbug import db, repository, test_scheduling
|
||||
from bugbug.model import Model
|
||||
from bugbug.models.testfailure import TestFailureModel
|
||||
from bugbug.utils import (
|
||||
download_and_load_model,
|
||||
download_check_etag,
|
||||
download_model,
|
||||
get_secret,
|
||||
to_array,
|
||||
zstd_decompress,
|
||||
|
@ -137,7 +139,7 @@ class CommitClassifier(object):
|
|||
self.model_name = model_name
|
||||
self.repo_dir = repo_dir
|
||||
|
||||
self.model = download_and_load_model(model_name)
|
||||
self.model = Model.load(download_model(model_name))
|
||||
assert self.model is not None
|
||||
|
||||
self.git_repo_dir = git_repo_dir
|
||||
|
@ -196,7 +198,9 @@ class CommitClassifier(object):
|
|||
)
|
||||
self.past_failures_data = test_scheduling.get_past_failures("label", True)
|
||||
|
||||
self.testfailure_model = download_and_load_model("testfailure")
|
||||
self.testfailure_model = cast(
|
||||
TestFailureModel, TestFailureModel.load(download_model("testfailure"))
|
||||
)
|
||||
assert self.testfailure_model is not None
|
||||
|
||||
def clone_git_repo(self, repo_url, repo_dir, rev="origin/branches/default/tip"):
|
||||
|
|
|
@ -10,17 +10,17 @@ import json
|
|||
import logging
|
||||
import os
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, cast
|
||||
|
||||
import dateutil.parser
|
||||
import requests
|
||||
from tqdm import tqdm
|
||||
|
||||
from bugbug import bugzilla, db, phabricator, repository, test_scheduling
|
||||
from bugbug.models.regressor import BUG_FIXING_COMMITS_DB
|
||||
from bugbug.models.regressor import BUG_FIXING_COMMITS_DB, RegressorModel
|
||||
from bugbug.utils import (
|
||||
download_and_load_model,
|
||||
download_check_etag,
|
||||
download_model,
|
||||
get_secret,
|
||||
zstd_decompress,
|
||||
)
|
||||
|
@ -77,7 +77,9 @@ class LandingsRiskReportGenerator(object):
|
|||
logger.info("Download commit classifications...")
|
||||
assert db.download(BUG_FIXING_COMMITS_DB)
|
||||
|
||||
self.regressor_model = download_and_load_model("regressor")
|
||||
self.regressor_model = cast(
|
||||
RegressorModel, RegressorModel.load(download_model("regressor"))
|
||||
)
|
||||
|
||||
bugzilla.set_token(get_secret("BUGZILLA_TOKEN"))
|
||||
phabricator.set_api_key(
|
||||
|
|
|
@ -29,11 +29,7 @@ from bugbug.models.regressor import (
|
|||
BUG_INTRODUCING_COMMITS_DB,
|
||||
TOKENIZED_BUG_INTRODUCING_COMMITS_DB,
|
||||
)
|
||||
from bugbug.utils import (
|
||||
ThreadPoolExecutorResult,
|
||||
download_and_load_model,
|
||||
zstd_compress,
|
||||
)
|
||||
from bugbug.utils import ThreadPoolExecutorResult, download_model, zstd_compress
|
||||
|
||||
basicConfig(level=INFO)
|
||||
logger = getLogger(__name__)
|
||||
|
@ -197,11 +193,14 @@ class RegressorFinder(object):
|
|||
# TODO: Switch to the pure Defect model, as it's better in this case.
|
||||
logger.info("Downloading defect/enhancement/task model...")
|
||||
defect_model = cast(
|
||||
DefectEnhancementTaskModel, download_and_load_model("defectenhancementtask")
|
||||
DefectEnhancementTaskModel,
|
||||
DefectEnhancementTaskModel.load(download_model("defectenhancementtask")),
|
||||
)
|
||||
|
||||
logger.info("Downloading regression model...")
|
||||
regression_model = cast(RegressionModel, download_and_load_model("regression"))
|
||||
regression_model = cast(
|
||||
RegressionModel, RegressionModel.load(download_model("regression"))
|
||||
)
|
||||
|
||||
start_date = datetime.now() - RELATIVE_START_DATE
|
||||
end_date = datetime.now() - RELATIVE_END_DATE
|
||||
|
|
Загрузка…
Ссылка в новой задаче