зеркало из https://github.com/mozilla/taar.git
Added ensemble test through the RecommendationManager
This commit is contained in:
Родитель
8e98f40bbb
Коммит
a1bb4b36db
|
@ -3,10 +3,12 @@ from .locale_recommender import LocaleRecommender
|
|||
from .legacy_recommender import LegacyRecommender
|
||||
from .similarity_recommender import SimilarityRecommender
|
||||
from .recommendation_manager import RecommendationManager, RecommenderFactory
|
||||
from .ensemble_recommender import EnsembleRecommender
|
||||
|
||||
|
||||
__all__ = [
|
||||
'CollaborativeRecommender',
|
||||
'EnsembleRecommender',
|
||||
'LegacyRecommender',
|
||||
'LocaleRecommender',
|
||||
'SimilarityRecommender',
|
||||
|
|
|
@ -50,7 +50,7 @@ class RecommendationManager:
|
|||
self.linear_recommenders.append(recommender)
|
||||
self._recommender_map[rkey] = recommender
|
||||
|
||||
self._recommender_map['ensemble'] = EnsembleRecommender(self.linear_recommenders)
|
||||
self._recommender_map['ensemble'] = EnsembleRecommender(self._recommender_map)
|
||||
|
||||
def recommend(self, client_id, limit, extra_data={}):
|
||||
"""Return recommendations for the given client.
|
||||
|
|
|
@ -8,7 +8,8 @@ import pytest
|
|||
from moto import mock_s3
|
||||
from taar.recommenders.ensemble_recommender import S3_BUCKET
|
||||
from taar.recommenders.ensemble_recommender import ENSEMBLE_WEIGHTS
|
||||
from taar.recommenders.ensemble_recommender import EnsembleRecommender
|
||||
from taar.recommenders import EnsembleRecommender
|
||||
from taar.recommenders import RecommendationManager
|
||||
from .mocks import MockRecommenderFactory
|
||||
|
||||
|
||||
|
@ -48,3 +49,27 @@ def test_recommendations(mock_s3_ensemble_weights):
|
|||
recommendation_list = r.recommend(client, 10)
|
||||
assert isinstance(recommendation_list, list)
|
||||
assert recommendation_list == EXPECTED_RESULTS
|
||||
|
||||
|
||||
def test_recommendations_via_manager(mock_s3_ensemble_weights):
|
||||
EXPECTED_RESULTS = [('cde', 12000.0),
|
||||
('bcd', 11000.0),
|
||||
('abc', 10023.0),
|
||||
('ghi', 3430.0),
|
||||
('def', 3320.0),
|
||||
('ijk', 3200.0),
|
||||
('hij', 3100.0),
|
||||
('lmn', 420.0),
|
||||
('klm', 409.99999999999994),
|
||||
('jkl', 400.0)]
|
||||
|
||||
factory = MockRecommenderFactory()
|
||||
|
||||
class MockProfileFetcher:
|
||||
def get(self, client_id):
|
||||
return {}
|
||||
|
||||
manager = RecommendationManager(factory, MockProfileFetcher())
|
||||
recommendation_list = manager.recommend('some_ignored_id', 10, extra_data={'branch': 'ensemble'})
|
||||
assert isinstance(recommendation_list, list)
|
||||
assert recommendation_list == EXPECTED_RESULTS
|
||||
|
|
Загрузка…
Ссылка в новой задаче