Refactored the weight loading from S3 into a cached singleton

This commit is contained in:
Victor Ng 2018-02-26 15:19:13 -05:00 коммит произвёл mlopatka
Родитель a20f4c1f43
Коммит 8e0f58b214
4 изменённых файлов: 53 добавлений и 12 удалений

Просмотреть файл

@ -3,7 +3,7 @@ from setuptools import find_packages, setup
setup(
name='mozilla-taar3',
use_scm_version=False,
version='0.0.22',
version='0.0.23',
setup_requires=['setuptools_scm', 'pytest-runner'],
tests_require=['pytest'],
include_package_data = True,

Просмотреть файл

@ -4,6 +4,7 @@ from .legacy_recommender import LegacyRecommender
from .similarity_recommender import SimilarityRecommender
from .recommendation_manager import RecommendationManager, RecommenderFactory
from .ensemble_recommender import EnsembleRecommender
from .ensemble_recommender import WeightCache
__all__ = [
@ -14,4 +15,5 @@ __all__ = [
'SimilarityRecommender',
'RecommendationManager',
'RecommenderFactory',
'WeightCache',
]

Просмотреть файл

@ -6,6 +6,8 @@ import logging
import itertools
from ..recommenders import utils
from .base_recommender import BaseRecommender
import threading
import time
S3_BUCKET = 'telemetry-parquet'
ENSEMBLE_WEIGHTS = 'taar/ensemble/ensemble_weight.json'
@ -13,6 +15,33 @@ ENSEMBLE_WEIGHTS = 'taar/ensemble/ensemble_weight.json'
logger = logging.getLogger(__name__)
class WeightCache:
    """Cache of the ensemble weights fetched from S3.

    The weights are loaded on first use and then served from memory
    until the entry expires, at which point the next call reloads them.
    All access to the cached state happens while holding a re-entrant
    lock.
    """

    # How long (in seconds) a loaded set of weights stays valid.
    TTL_SECONDS = 300

    def __init__(self):
        self._lock = threading.RLock()
        self._weights = None
        # Absolute timestamp after which the cached weights are stale;
        # None until the first successful load.
        self._expiry = None

    def now(self):
        """Return the current time in seconds (overridable clock hook)."""
        return time.time()

    def getWeights(self):
        """Return the ensemble weights, reloading them when stale.

        Returns:
            The value stored under the 'ensemble_weights' key of the
            JSON document read from S3.
        """
        with self._lock:
            now = self.now()

            if self._expiry is not None and self._expiry < now:
                # Cache is expired.
                self._weights = None

            if self._weights is None:
                tmp = utils.get_s3_json_content(S3_BUCKET, ENSEMBLE_WEIGHTS)
                self._weights = tmp['ensemble_weights']
                # Arm the expiry only after a successful load.  Previously
                # the expiry was assigned solely inside the
                # "expiry is not None" branch, so it never left None and
                # the cache never refreshed.
                self._expiry = now + self.TTL_SECONDS

            return self._weights
class EnsembleRecommender(BaseRecommender):
"""
The EnsembleRecommender is a collection of recommenders where the
@ -21,16 +50,12 @@ class EnsembleRecommender(BaseRecommender):
addons for users.
"""
def __init__(self, recommender_map):
tmp = utils.get_s3_json_content(S3_BUCKET, ENSEMBLE_WEIGHTS)
self._ensemble_weights = tmp['ensemble_weights']
# Copy the map of the recommenders
# TODO: verify that the recommender keys match what we've used
# in the ensemble training
self.RECOMMENDER_KEYS = ['legacy', 'collaborative', 'similarity', 'locale']
self._recommender_map = recommender_map
self._weight_cache = WeightCache()
def can_recommend(self, client_data, extra_data={}):
"""The ensemble recommender is always going to be
available if at least one recommender is available"""
@ -53,15 +78,16 @@ class EnsembleRecommender(BaseRecommender):
"""
flattened_results = []
ensemble_weights = self._weight_cache.getWeights()
for rkey in self.RECOMMENDER_KEYS:
recommender = self._recommender_map[rkey]
if recommender.can_recommend(client_data):
raw_results = recommender.recommend(client_data, limit, extra_data)
reweighted_results = []
for guid, weight in raw_results:
item = (guid, weight * self._ensemble_weights[rkey])
item = (guid, weight * ensemble_weights[rkey])
reweighted_results.append(item)
flattened_results.extend(reweighted_results)
@ -81,5 +107,9 @@ class EnsembleRecommender(BaseRecommender):
# Sort in reverse order (greatest weight to least)
ensemble_suggestions.sort(key=lambda x: -x[1])
results = ensemble_suggestions[:limit]
print(results, flattened_results)
log_data = (client_data['client_id'],
str(ensemble_weights),
str([r[0] for r in results]))
logger.info("client_id: [%s], ensemble_weight: [%s], guids: [%s]" % log_data)
return results

Просмотреть файл

@ -1,9 +1,18 @@
from taar.recommenders import EnsembleRecommender
from taar.recommenders import EnsembleRecommender, WeightCache
from .mocks import MockRecommenderFactory # noqa
from .mocks import mock_s3_ensemble_weights # noqa
def test_recommendations(mock_s3_ensemble_weights): # noqa
def test_weight_cache(mock_s3_ensemble_weights):  # noqa
    """A fresh WeightCache returns the expected per-recommender weights."""
    weights = WeightCache().getWeights()
    assert dict(legacy=10000,
                collaborative=1000,
                similarity=100,
                locale=10) == weights
def test_recommendations(mock_s3_ensemble_weights): # noqa
EXPECTED_RESULTS = [('cde', 12000.0),
('bcd', 11000.0),
('abc', 10023.0),