allow or queries, fix len and slices

This commit is contained in:
Jeff Balogh 2011-06-10 14:42:26 -07:00
Родитель 3c3738b992
Коммит 7bfe9b9312
2 изменённых файлов: 39 добавлений и 5 удалений

Просмотреть файл

@ -1,4 +1,5 @@
import logging
import pprint
from django.conf import settings
@ -14,6 +15,7 @@ class ES(object):
self.type = type_
self.filters = {}
self.in_ = {}
self.or_ = {}
self.queries = {}
self.fields = ['id']
self.ordering = []
@ -25,6 +27,7 @@ class ES(object):
new = self.__class__(self.type)
new.filters = dict(self.filters)
new.in_ = dict(self.in_)
new.or_ = list(self.or_)
new.queries = dict(self.queries)
new.fields = list(self.fields)
new.ordering = list(self.ordering)
@ -60,17 +63,24 @@ class ES(object):
new.filters[key] = value
return new
# This is a lame hack.
def filter_or(self, **kw):
new = self._clone()
new.or_.append(kw)
return new
def count(self):
num = self._do_search().count
self._results_cache = None
return num
__len__ = count
def __len__(self):
return len(self._do_search())
def __getitem__(self, k):
# TODO: validate numbers and ranges
if isinstance(k, slice):
self.start, self.stop = k.start, k.stop
self.start, self.stop = k.start or 0, k.stop
return self
else:
self.start, self.stop = k, k + 1
@ -81,13 +91,18 @@ class ES(object):
if self.queries:
qs['query'] = {'term': self.queries}
if len(self.filters) + len(self.in_) > 1:
if len(self.filters) + len(self.in_) + len(self.or_) > 1:
qs['filter'] = {'and': []}
and_ = qs['filter']['and']
for key, value in self.filters.items():
and_.append({'term': {key: value}})
for key, value in self.in_.items():
and_.append({'in': {key: value}})
for dict_ in self.or_:
or_ = []
for key, value in dict_.items():
or_.append({'term': {key: value}})
and_.append({'or': or_})
elif self.filters:
qs['filter'] = {'term': self.filters}
elif self.in_:
@ -108,7 +123,7 @@ class ES(object):
if not self._results_cache:
qs = self._build_query()
es = elasticutils.get_es()
log.debug(qs)
log.debug(pprint.pformat(qs))
hits = es.search(qs, settings.ES_INDEX, self.type._meta.app_label)
self._results_cache = SearchResults(self.type, hits)
return self._results_cache
@ -125,7 +140,7 @@ class SearchResults(object):
statsd.timing('search', self.took)
log.debug('Query took %dms.' % self.took)
self.count = results['hits']['total']
self.ids = [r['fields']['id'] for r in results['hits']['hits']]
self.ids = [int(r['_id']) for r in results['hits']['hits']]
self.objects = self.type.objects.filter(id__in=self.ids)
self.results = results

Просмотреть файл

@ -64,6 +64,20 @@ class TestES(amo.tests.ESTestCase):
'from': 5,
'size': 7})
def test_or(self):
qs = Addon.search().filter(type=1).filter_or(status=1, app=2)
eq_(qs._build_query(), {'fields': ['id'],
'filter': {'and': [
{'term': {'type': 1}},
{'or': [{'term': {'status': 1}},
{'term': {'app': 2}}]},
]}})
def test_slice_stop(self):
qs = Addon.search()[:6]
eq_(qs._build_query(), {'fields': ['id'],
'size': 6})
def test_getitem(self):
addons = list(Addon.search())
eq_(addons[0], Addon.search()[0])
@ -75,3 +89,8 @@ class TestES(amo.tests.ESTestCase):
def test_count(self):
eq_(Addon.search().count(), 6)
def test_len(self):
qs = Addon.search()
qs._results_cache = [1]
eq_(len(qs), 1)