Adapt QueryInterface specifically for localities since it isn't used for events anymore

This commit is contained in:
Emma Rose 2019-08-13 19:12:24 -04:00
Родитель efcc19de8e
Коммит fa6b8da8a3
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 1486642516ED3535
4 изменённых файлов: 88 добавлений и 174 удалений

Просмотреть файл

@ -5,7 +5,6 @@ from mozdef_util.elasticsearch_client import ElasticsearchClient as ESClient
from mozdef_util.query_models import SearchQuery, TermMatch
import alerts.geomodel.config as config
import alerts.geomodel.query as query
# Default radius (in Kilometres) that a locality should have.
@ -13,10 +12,6 @@ _DEFAULT_RADIUS_KM = 50.0
# TODO: Switch to dataclasses when we move to Python3.7+
def _dict_take(dictionary, keys):
return {key: dictionary[key] for key in keys}
class Locality(NamedTuple):
'''Represents a specific locality.
'''
@ -29,7 +24,6 @@ class Locality(NamedTuple):
longitude: float
radius: int
class State(NamedTuple):
'''Represents the state tracked for each user regarding their localities.
'''
@ -49,23 +43,6 @@ class Entry(NamedTuple):
identifier: Optional[str]
state: State
# Type of a function that is called with a list of `Entry` values and the
# name of an index and returns nothing; `wrap_journal` produces one.
JournalInterface = Callable[[List[Entry], str], None]
def wrap_journal(client: ESClient) -> JournalInterface:
    '''Wrap an `ElasticsearchClient` in a closure of type `JournalInterface`.

    The returned closure saves the state of every entry it is given as one
    document in the named index, using the entry's identifier as the
    document ID.
    '''

    def wrapper(entries: List[Entry], esindex: str):
        for entry in entries:
            document = dict(entry.state._asdict())

            # `_asdict` is shallow: the nested localities would still be
            # tuples.  Convert them explicitly so each one is stored with its
            # field names no matter how the client encodes tuples; code that
            # reads these documents back looks localities up by field name.
            document['localities'] = [
                dict(loc._asdict()) if hasattr(loc, '_asdict') else loc
                for loc in entry.state.localities
            ]

            client.save_object(
                index=esindex,
                body=document,
                doc_id=entry.identifier)

    return wrapper
class Update(NamedTuple):
'''Produced by calls to functions operating on lists of `State`s to
indicate when an update was applied without having to maintain distinct
@ -86,6 +63,12 @@ class Update(NamedTuple):
return Update(new.state, u.did_update or new.did_update)
# Type of a function that is called with a list of `Entry` values and the
# name of an index and returns nothing; `wrap_journal` produces one.
JournalInterface = Callable[[List[Entry], str], None]
# Type of a function that is called with a `SearchQuery` and the name of an
# index and returns a list of `Entry` values; `wrap_query` produces one.
QueryInterface = Callable[[SearchQuery, str], List[Entry]]
def _dict_take(dictionary, keys):
return {key: dictionary[key] for key in keys}
def _update(state: State, from_evt: State) -> Update:
did_update = False
@ -115,6 +98,52 @@ def _update(state: State, from_evt: State) -> Update:
return Update(state, did_update)
def wrap_journal(client: ESClient) -> JournalInterface:
    '''Wrap an `ElasticsearchClient` in a closure of type `JournalInterface`.

    The returned closure saves the state of every entry it is given as one
    document in the named index, using the entry's identifier as the
    document ID.
    '''

    def wrapper(entries: List[Entry], esindex: str):
        for entry in entries:
            document = dict(entry.state._asdict())

            # `_asdict` is shallow: the nested localities would still be
            # tuples.  Convert them explicitly so each one is stored with its
            # field names no matter how the client encodes tuples; code that
            # reads these documents back looks localities up by field name.
            document['localities'] = [
                dict(loc._asdict()) if hasattr(loc, '_asdict') else loc
                for loc in entry.state.localities
            ]

            client.save_object(
                index=esindex,
                body=document,
                doc_id=entry.identifier)

    return wrapper
def wrap_query(client: ESClient) -> QueryInterface:
    '''Wrap an `ElasticsearchClient` in a closure of type `QueryInterface`.

    The returned closure executes a `SearchQuery` against the named index and
    turns each hit into an `Entry`.  Hits whose `_source` cannot be converted
    into a `State` are left out of the result.
    '''

    def to_state(result: Dict[str, Any]) -> Optional[State]:
        # Produces `None` when a required key is missing (KeyError) or a
        # value does not have the expected shape (TypeError).
        try:
            # Work on a fresh dictionary so that the hit handed to us is not
            # modified, whether or not the conversion succeeds.
            fields = _dict_take(result, State._fields)
            fields['localities'] = [
                Locality(**_dict_take(loc, Locality._fields))
                for loc in fields['localities']
            ]

            return State(**fields)
        except (TypeError, KeyError):
            return None

    def wrapper(query: SearchQuery, esindex: str) -> List[Entry]:
        hits = query.execute(client, indices=[esindex]).get('hits', [])

        entries = []
        for hit in hits:
            opt_state = to_state(hit.get('_source', {}))

            if opt_state is not None:
                entries.append(Entry(hit['_id'], opt_state))

        return entries

    return wrapper
def from_event(
event: Dict[str, Any],
radius=_DEFAULT_RADIUS_KM
@ -151,38 +180,16 @@ def from_event(
radius)
def find_all(
        query_es: QueryInterface,
        locality: config.Localities
) -> List[Entry]:
    '''Retrieve all locality state from ElasticSearch.

    `query_es` is called with a search for documents whose `type_` is
    `'locality'` and with `locality.es_index`.  Its result is returned as-is:
    a `QueryInterface` already yields `Entry` values, so no further parsing
    of raw hits happens here.
    '''

    search = SearchQuery()
    search.add_must([TermMatch('type_', 'locality')])

    return query_es(search, locality.es_index)
def merge(persisted: List[State], event_sourced: List[State]) -> List[Update]:
'''Merge together a list of states already stored in ElasticSearch

Просмотреть файл

@ -1,23 +0,0 @@
'''To make GeoModel code more testable, we abstract interaction with
ElasticSearch away via a "query interface". This is just a function that,
called with a `SearchQuery` and an ES index, produces a list of dictionaries
as output.
'''
from typing import Any, Callable, Dict, List
from mozdef_util.elasticsearch_client import ElasticsearchClient as ESClient
from mozdef_util.query_models import SearchQuery
# Arguments are the query to run and the name of the index, in that order;
# the result is a list of dictionaries.
QueryInterface = Callable[[SearchQuery, str], List[Dict[str, Any]]]
def wrap(client: ESClient) -> QueryInterface:
    '''Produce a `QueryInterface` closure backed by the given
    `ElasticsearchClient`.

    The closure executes the query against the single named index and hands
    back the value stored under `'hits'` in the response, or an empty list
    when that key is absent.
    '''

    def run_search(query: SearchQuery, esindex: str) -> List[Dict[str, Any]]:
        response = query.execute(client, indices=[esindex])
        return response.get('hits', [])

    return run_search

Просмотреть файл

@ -1,15 +1,25 @@
from datetime import datetime, timedelta
import pytz
from typing import Any, Dict, List
import unittest
from mozdef_util.query_models import SearchQuery
import alerts.geomodel.config as config
import alerts.geomodel.locality as locality
import alerts.geomodel.query as query
from tests.alerts.geomodel.util import query_interface
from tests.unit_test_suite import UnitTestSuite
def query_interface(results: List[locality.Entry]) -> locality.QueryInterface:
    '''Build a stand-in `QueryInterface` for tests.

    Whatever query and index it is called with, the returned function gives
    back the `results` list supplied here.
    '''

    def canned(q: SearchQuery, esi: str) -> List[locality.Entry]:
        # Both arguments are deliberately ignored.
        return results

    return canned
class TestLocalityElasticSearch(UnitTestSuite):
'''Tests for the `locality` module that interact with ES.
'''
@ -38,7 +48,7 @@ class TestLocalityElasticSearch(UnitTestSuite):
self.refresh(self.event_index_name)
query_iface = query.wrap(self.es_client)
query_iface = locality.wrap_query(self.es_client)
loc_cfg = config.Localities(self.event_index_name, 30, 50.0)
entries = locality.find_all(query_iface, loc_cfg)
@ -88,7 +98,7 @@ class TestLocalityElasticSearch(UnitTestSuite):
self.refresh(self.event_index_name)
query_iface = query.wrap(self.es_client)
query_iface = locality.wrap_query(self.es_client)
loc_cfg = config.Localities(self.event_index_name, 30, 50.0)
retrieved = locality.find_all(query_iface, loc_cfg)
@ -102,42 +112,26 @@ class TestLocality(unittest.TestCase):
def test_find_all_retrieves_all_states(self):
query_iface = query_interface([
{
'_id': 'id1',
'_source': {
'type_': 'locality',
'username': 'tester1',
'localities': [
{
'sourceipaddress': '1.2.3.4',
'city': 'Toronto',
'country': 'CA',
'lastaction': datetime.utcnow(),
'latitude': 43.6529,
'longitude': -79.3849,
'radius': 50
}
]
}
},
{
'_id': 'id2',
'_source': {
'type_': 'locality',
'username': 'tester2',
'localities': [
{
'sourceipaddress': '4.3.2.1',
'city': 'San Francisco',
'country': 'USA',
'lastaction': datetime.utcnow(),
'latitude': 37.773972,
'longitude': -122.431297,
'radius': 50
}
]
}
}
locality.Entry('id1', locality.State('locality', 'tester1', [
locality.Locality(
sourceipaddress='1.2.3.4',
city='Toronto',
country='CA',
lastaction=datetime.utcnow() - timedelta(minutes=3),
latitude=43.6529,
longitude=-79.3849,
radius=50)
])),
locality.Entry('id2', locality.State('locality', 'tester2', [
locality.Locality(
sourceipaddress='4.3.2.1',
city='San Francisco',
country='USA',
lastaction=datetime.utcnow(),
latitude=37.773972,
longitude=-122.431297,
radius=50)
]))
])
loc_cfg = config.Localities('localities', 30, 50.0)
@ -152,55 +146,6 @@ class TestLocality(unittest.TestCase):
assert len(entries[0].state.localities) == 1
assert len(entries[1].state.localities) == 1
def test_find_all_ignores_invalid_data(self):
    # Of the three documents below only the second is well-formed, so it
    # is the only one `find_all` is expected to return.

    # Invalid top-level State
    bad_state = {
        '_id': 'id1',
        '_source': {
            'type__': 'locality',  # Should have only one underscore (_)
            'username': 'tester',
            'localities': []
        }
    }

    # Valid State
    good_state = {
        '_id': 'id2',
        '_source': {
            'type_': 'locality',
            'username': 'validtester',
            'localities': []
        }
    }

    # Invalid locality data
    bad_locality = {
        '_id': 'id3',
        '_source': {
            'type_': 'locality',
            'username': 'tester2',
            'localities': [
                {
                    # Should be sourceipaddress; missing a 'd'
                    'sourceipadress': '1.2.3.4',
                    'city': 'San Francisco',
                    'country': 'USA',
                    'lastaction': datetime.utcnow(),
                    'latitude': 37.773972,
                    'longitude': -122.431297,
                    'radius': 50
                }
            ]
        }
    }

    query_iface = query_interface([bad_state, good_state, bad_locality])
    loc_cfg = config.Localities('localities', 30, 50.0)

    found = locality.find_all(query_iface, loc_cfg)
    names = []
    for item in found:
        names.append(item.state.username)

    assert len(found) == 1
    assert names == ['validtester']
def test_merge_updates_localities(self):
from_es = [
locality.State('locality', 'user1', [

Просмотреть файл

@ -1,15 +0,0 @@
from typing import Any, Dict, List
from mozdef_util.query_models import SearchQuery
import alerts.geomodel.query as query
def query_interface(results: List[Dict[str, Any]]) -> query.QueryInterface:
    '''Build a stand-in `QueryInterface` for tests.

    Whatever query and index it is called with, the returned function gives
    back the `results` list supplied here.
    '''

    def canned(q: SearchQuery, esi: str) -> List[Dict[str, Any]]:
        # Both arguments are deliberately ignored.
        return results

    return canned