Add unit test support for query should param

Signed-off-by: Brandon Myers <bmyers@mozilla.com>
This commit is contained in:
Brandon Myers 2016-11-01 18:57:39 -05:00
Родитель b8f9aa8d10
Коммит ccebf7344d
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 8AA79AD83045BBC7
5 изменённых файлов: 88 добавлений и 62 удалений

Просмотреть файл

@ -9,6 +9,12 @@ from datetime import datetime
from datetime import timedelta
def MultiMatch(key, value):
if pyes_enabled.pyes_on is True:
return pyes.QueryFilter(pyes.MatchQuery(key, value, 'boolean'))
return Q('multi_match', query=key, fields=value)
def ExistsMatch(field_name):
if pyes_enabled.pyes_on is True:
return pyes.ExistsFilter(field_name)
@ -197,9 +203,9 @@ class SearchQuery():
def add_aggregation(self, input_obj):
self.append_to_array(self.aggregation, input_obj)
def execute(self, elasticsearch_client, indices=['events', 'events-previous']):
if self.must == [] and self.must_not == [] and self.should == []:
raise AttributeError('Must define a must, must_not, or should query')
def execute(self, elasticsearch_client, indices=['events', 'events-previous'], size=1000):
if self.must == [] and self.must_not == [] and self.should == [] and self.aggregation == []:
raise AttributeError('Must define a must, must_not, should query, or aggregation')
if self.date_timedelta:
end_date = toUTC(datetime.now())
@ -210,21 +216,15 @@ class SearchQuery():
search_query = None
if pyes_enabled.pyes_on is True:
search_query = pyes.ConstantScoreQuery(pyes.MatchAllQuery())
search_query.filters.append(BooleanMatch(
must=self.must, should=self.should, must_not=self.must_not))
search_query.filters.append(BooleanMatch(must=self.must, should=self.should, must_not=self.must_not))
else:
search_query = BooleanMatch(
must=self.must, must_not=self.must_not, should=self.should)
# Remove try catch statement when we remove pyes_on
# results = []
try:
if len(self.aggregation) == 0:
results = elasticsearch_client.search(search_query, indices)
else:
results = elasticsearch_client.aggregated_search(
search_query, indices, self.aggregation)
except RuntimeError:
results = []
results = []
if len(self.aggregation) == 0:
results = elasticsearch_client.search(search_query, indices, size)
else:
results = elasticsearch_client.aggregated_search(search_query, indices, self.aggregation, size)
return results

Просмотреть файл

@ -5,7 +5,6 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "../"))
from unit_test_suite import UnitTestSuite
import random
import copy
import re
@ -146,35 +145,6 @@ class AlertTestSuite(UnitTestSuite):
else:
assert len(alert_task.alert_ids) is 0
def random_ip(self):
return str(random.randint(1, 255)) + "." + str(random.randint(1, 255)) + "." + str(random.randint(1, 255)) + "." + str(random.randint(1, 255))
def generate_default_event(self):
current_timestamp = UnitTestSuite.current_timestamp_lambda()
source_ip = self.random_ip()
event = {
"_index": "events",
"_type": "event",
"_source": {
"category": "excategory",
"utctimestamp": current_timestamp,
"hostname": "exhostname",
"severity": "NOTICE",
"source": "exsource",
"summary": "Example summary",
"tags": ['tag1', 'tag2'],
"details": {
"sourceipaddress": source_ip,
"hostname": "exhostname"
}
}
}
return event
@staticmethod
def copy(obj):
return copy.deepcopy(obj)

Просмотреть файл

@ -35,9 +35,3 @@ class QueryTestSuite(UnitTestSuite):
search_query.add_must_not(query)
query_result = search_query.execute(self.es_client)
self.verify_test(query_result, self.positive_test is False)
# Testing should
# todo: figure out a way to automagically test 'should'

Просмотреть файл

@ -21,6 +21,16 @@ class SearchQueryUnitTest(UnitTestSuite):
assert self.query.should == []
assert self.query.aggregation == []
def populate_example_event(self):
event = {
'summary': 'Test Summary',
'note': 'Example note',
'details': {
'information': 'Example information'
}
}
self.populate_test_event(event)
class TestMustInput(SearchQueryUnitTest):
@ -96,18 +106,7 @@ class TestAggregationInput(SearchQueryUnitTest):
class TestExecute(SearchQueryUnitTest):
def populate_example_event(self):
event = {
'summary': 'Test Summary',
'note': 'Example note',
'details': {
'information': 'Example information'
}
}
self.populate_test_event(event)
def test_complex_aggregation_query_execute(self):
self.setup()
query = SearchQuery()
assert query.date_timedelta == {}
query.add_must(ExistsMatch('ip'))
@ -200,6 +199,16 @@ class TestExecute(SearchQueryUnitTest):
assert results['aggregations']['ip']['terms'][1]['count'] == 2
assert results['aggregations']['ip']['terms'][1]['key'] == '1.2.3.4'
def test_aggregation_without_must_fields(self):
event = self.generate_default_event()
event['_source']['utctimestamp'] = event['_source']['utctimestamp']()
self.populate_test_event(event)
search_query = SearchQuery(minutes=10)
search_query.add_aggregation(Aggregation('summary'))
results = search_query.execute(self.es_client)
assert results['aggregations']['summary']['terms'][0]['count'] == 1
def test_aggregation_query_execute(self):
self.setup()
query = SearchQuery()
@ -440,3 +449,26 @@ class TestExecute(SearchQueryUnitTest):
query = SearchQuery(minutes=10)
with pytest.raises(AttributeError):
query.execute(self.es_client)
def test_execute_with_size(self):
for num in range(0, 30):
self.populate_example_event()
query = SearchQuery()
query.add_must(ExistsMatch('summary'))
results = query.execute(self.es_client, size=12)
assert len(results['hits']) == 12
def test_execute_without_size(self):
for num in range(0, 1200):
self.populate_example_event()
query = SearchQuery()
query.add_must(ExistsMatch('summary'))
results = query.execute(self.es_client)
assert len(results['hits']) == 1000
def test_execute_with_should(self):
self.populate_example_event()
self.query.add_should(ExistsMatch('summary'))
self.query.add_should(ExistsMatch('nonexistentfield'))
results = self.query.execute(self.es_client)
assert len(results['hits']) == 1

Просмотреть файл

@ -11,6 +11,8 @@ from datetime import datetime
from datetime import timedelta
from dateutil.parser import parse
import random
class UnitTestSuite(object):
def setup(self):
@ -53,6 +55,34 @@ class UnitTestSuite(object):
# self.es_client.delete_template('eventstemplate')
# self.es_client.delete_template('alertstemplate')
def random_ip(self):
return str(random.randint(1, 255)) + "." + str(random.randint(1, 255)) + "." + str(random.randint(1, 255)) + "." + str(random.randint(1, 255))
def generate_default_event(self):
current_timestamp = UnitTestSuite.current_timestamp_lambda()
source_ip = self.random_ip()
event = {
"_index": "events",
"_type": "event",
"_source": {
"category": "excategory",
"utctimestamp": current_timestamp,
"hostname": "exhostname",
"severity": "NOTICE",
"source": "exsource",
"summary": "Example summary",
"tags": ['tag1', 'tag2'],
"details": {
"sourceipaddress": source_ip,
"hostname": "exhostname"
}
}
}
return event
@staticmethod
def current_timestamp():
return toUTC(datetime.now()).isoformat()