зеркало из https://github.com/mozilla/taar.git
Refactored the weight loading from S3 into a cached singleton
This commit is contained in:
Родитель
a20f4c1f43
Коммит
8e0f58b214
2
setup.py
2
setup.py
|
@ -3,7 +3,7 @@ from setuptools import find_packages, setup
|
|||
setup(
|
||||
name='mozilla-taar3',
|
||||
use_scm_version=False,
|
||||
version='0.0.22',
|
||||
version='0.0.23',
|
||||
setup_requires=['setuptools_scm', 'pytest-runner'],
|
||||
tests_require=['pytest'],
|
||||
include_package_data = True,
|
||||
|
|
|
@ -4,6 +4,7 @@ from .legacy_recommender import LegacyRecommender
|
|||
from .similarity_recommender import SimilarityRecommender
|
||||
from .recommendation_manager import RecommendationManager, RecommenderFactory
|
||||
from .ensemble_recommender import EnsembleRecommender
|
||||
from .ensemble_recommender import WeightCache
|
||||
|
||||
|
||||
__all__ = [
|
||||
|
@ -14,4 +15,5 @@ __all__ = [
|
|||
'SimilarityRecommender',
|
||||
'RecommendationManager',
|
||||
'RecommenderFactory',
|
||||
'WeightCache',
|
||||
]
|
||||
|
|
|
@ -6,6 +6,8 @@ import logging
|
|||
import itertools
|
||||
from ..recommenders import utils
|
||||
from .base_recommender import BaseRecommender
|
||||
import threading
|
||||
import time
|
||||
|
||||
S3_BUCKET = 'telemetry-parquet'
|
||||
ENSEMBLE_WEIGHTS = 'taar/ensemble/ensemble_weight.json'
|
||||
|
@ -13,6 +15,33 @@ ENSEMBLE_WEIGHTS = 'taar/ensemble/ensemble_weight.json'
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WeightCache:
    """Thread-safe, time-expiring cache for the ensemble weights in S3.

    The weights JSON is fetched lazily on first access and re-fetched
    after the TTL elapses.  All access is serialized through an RLock so
    the cache can be shared between request-handling threads.
    """

    # Seconds before a cached copy of the weights is considered stale.
    TTL_SECONDS = 300

    def __init__(self):
        self._lock = threading.RLock()

        # Cached weight dict and its absolute expiry timestamp (epoch
        # seconds).  Both start empty so the first getWeights() fetches.
        self._weights = None
        self._expiry = None

    def now(self):
        """Return the current epoch time; isolated so tests can override it."""
        return time.time()

    def getWeights(self):
        """Return the ensemble weight dict, refreshing from S3 when stale.

        Bugfix: the original only pushed ``self._expiry`` forward inside
        the "already expired" branch, so the expiry was never initialized
        after the first fetch (it stayed ``None``) and the cache never
        actually expired.  The expiry is now stamped whenever a fresh
        copy is downloaded.
        """
        with self._lock:
            now = self.now()

            # Drop a stale copy so the fetch below refreshes it.
            if self._expiry is not None and self._expiry < now:
                self._weights = None

            if self._weights is None:
                tmp = utils.get_s3_json_content(S3_BUCKET, ENSEMBLE_WEIGHTS)
                self._weights = tmp['ensemble_weights']
                # Stamp the TTL on every fresh fetch (see docstring).
                self._expiry = now + self.TTL_SECONDS

            return self._weights
|
||||
|
||||
|
||||
class EnsembleRecommender(BaseRecommender):
|
||||
"""
|
||||
The EnsembleRecommender is a collection of recommenders where the
|
||||
|
@ -21,16 +50,12 @@ class EnsembleRecommender(BaseRecommender):
|
|||
addons for users.
|
||||
"""
|
||||
def __init__(self, recommender_map):
|
||||
tmp = utils.get_s3_json_content(S3_BUCKET, ENSEMBLE_WEIGHTS)
|
||||
self._ensemble_weights = tmp['ensemble_weights']
|
||||
|
||||
# Copy the map of the recommenders
|
||||
|
||||
# TODO: verify that the recommender keys match what we've used
|
||||
# in the ensemble training
|
||||
self.RECOMMENDER_KEYS = ['legacy', 'collaborative', 'similarity', 'locale']
|
||||
self._recommender_map = recommender_map
|
||||
|
||||
self._weight_cache = WeightCache()
|
||||
|
||||
def can_recommend(self, client_data, extra_data={}):
|
||||
"""The ensemble recommender is always going to be
|
||||
available if at least one recommender is available"""
|
||||
|
@ -53,15 +78,16 @@ class EnsembleRecommender(BaseRecommender):
|
|||
"""
|
||||
|
||||
flattened_results = []
|
||||
ensemble_weights = self._weight_cache.getWeights()
|
||||
|
||||
for rkey in self.RECOMMENDER_KEYS:
|
||||
recommender = self._recommender_map[rkey]
|
||||
|
||||
if recommender.can_recommend(client_data):
|
||||
raw_results = recommender.recommend(client_data, limit, extra_data)
|
||||
|
||||
reweighted_results = []
|
||||
for guid, weight in raw_results:
|
||||
item = (guid, weight * self._ensemble_weights[rkey])
|
||||
item = (guid, weight * ensemble_weights[rkey])
|
||||
reweighted_results.append(item)
|
||||
flattened_results.extend(reweighted_results)
|
||||
|
||||
|
@ -81,5 +107,9 @@ class EnsembleRecommender(BaseRecommender):
|
|||
# Sort in reverse order (greatest weight to least)
|
||||
ensemble_suggestions.sort(key=lambda x: -x[1])
|
||||
results = ensemble_suggestions[:limit]
|
||||
print(results, flattened_results)
|
||||
|
||||
log_data = (client_data['client_id'],
|
||||
str(ensemble_weights),
|
||||
str([r[0] for r in results]))
|
||||
logger.info("client_id: [%s], ensemble_weight: [%s], guids: [%s]" % log_data)
|
||||
return results
|
||||
|
|
|
@ -1,9 +1,18 @@
|
|||
from taar.recommenders import EnsembleRecommender
|
||||
from taar.recommenders import EnsembleRecommender, WeightCache
|
||||
from .mocks import MockRecommenderFactory # noqa
|
||||
from .mocks import mock_s3_ensemble_weights # noqa
|
||||
|
||||
def test_recommendations(mock_s3_ensemble_weights): # noqa
|
||||
def test_weight_cache(mock_s3_ensemble_weights):  # noqa
    """The cache must hand back exactly the weights the S3 mock provides."""
    cache = WeightCache()
    expected = {'legacy': 10000,
                'collaborative': 1000,
                'similarity': 100,
                'locale': 10}
    assert cache.getWeights() == expected
|
||||
|
||||
|
||||
def test_recommendations(mock_s3_ensemble_weights): # noqa
|
||||
EXPECTED_RESULTS = [('cde', 12000.0),
|
||||
('bcd', 11000.0),
|
||||
('abc', 10023.0),
|
||||
|
|
Загрузка…
Ссылка в новой задаче