[AIRFLOW-2640] Add Cassandra table sensor

Just like a partition sensor for Hive,
this PR adds a sensor that waits for
a table to be created in a Cassandra cluster.

Closes #3518 from sekikn/AIRFLOW-2640
This commit is contained in:
Kengo Seki 2018-06-20 20:36:32 +02:00 коммит произвёл Fokko Driesprong
Родитель 6ed2fb716a
Коммит 5f49ebf018
6 изменённых файлов: 176 добавлений и 15 удалений

Просмотреть файл

@ -158,6 +158,21 @@ class CassandraHook(BaseHook, LoggingMixin):
child_policy_args) child_policy_args)
return TokenAwarePolicy(child_policy) return TokenAwarePolicy(child_policy)
def table_exists(self, table):
    """
    Tell whether a table is listed in the Cassandra cluster metadata.

    :param table: Target Cassandra table.
                  Use dot notation to target a specific keyspace;
                  otherwise the hook's own keyspace is used.
    :type table: string
    :return: True if both the keyspace and the table are present in the
             cluster metadata, False otherwise.
    :rtype: bool
    """
    fallback_keyspace = self.keyspace
    # Only the first dot separates keyspace from table name.
    head, dot, tail = table.partition('.')
    if dot:
        keyspace, table_name = head, tail
    else:
        keyspace, table_name = fallback_keyspace, table

    known_keyspaces = self.get_conn().cluster.metadata.keyspaces
    if keyspace not in known_keyspaces:
        return False
    return table_name in known_keyspaces[keyspace].tables
def record_exists(self, table, keys): def record_exists(self, table, keys):
""" """
Checks if a record exists in Cassandra Checks if a record exists in Cassandra
@ -168,12 +183,12 @@ class CassandraHook(BaseHook, LoggingMixin):
:param keys: The keys and their values to check the existence. :param keys: The keys and their values to check the existence.
:type keys: dict :type keys: dict
""" """
keyspace = None keyspace = self.keyspace
if '.' in table: if '.' in table:
keyspace, table = table.split('.', 1) keyspace, table = table.split('.', 1)
ks = " AND ".join("{}=%({})s".format(key, key) for key in keys.keys()) ks = " AND ".join("{}=%({})s".format(key, key) for key in keys.keys())
cql = "SELECT * FROM {keyspace}.{table} WHERE {keys}".format( cql = "SELECT * FROM {keyspace}.{table} WHERE {keys}".format(
keyspace=(keyspace or self.keyspace), table=table, keys=ks) keyspace=keyspace, table=table, keys=ks)
try: try:
rs = self.get_conn().execute(cql, keys) rs = self.get_conn().execute(cql, keys)

Просмотреть файл

@ -29,9 +29,10 @@ class CassandraRecordSensor(BaseSensorOperator):
primary keys 'p1' and 'p2' to be populated in keyspace 'k' and table 't', primary keys 'p1' and 'p2' to be populated in keyspace 'k' and table 't',
instantiate it as follows: instantiate it as follows:
>>> CassandraRecordSensor(table="k.t", keys={"p1": "v1", "p2": "v2"}, >>> cassandra_sensor = CassandraRecordSensor(table="k.t",
... cassandra_conn_id="cassandra_default", task_id="cassandra_sensor") ... keys={"p1": "v1", "p2": "v2"},
<Task(CassandraRecordSensor): cassandra_sensor> ... cassandra_conn_id="cassandra_default",
... task_id="cassandra_sensor")
""" """
template_fields = ('table', 'keys') template_fields = ('table', 'keys')

Просмотреть файл

@ -0,0 +1,56 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from airflow.contrib.hooks.cassandra_hook import CassandraHook
from airflow.sensors.base_sensor_operator import BaseSensorOperator
from airflow.utils.decorators import apply_defaults
class CassandraTableSensor(BaseSensorOperator):
    """
    Checks whether a table exists in a Cassandra cluster.

    For example, to wait for a table called 't' to be created
    in a keyspace 'k', instantiate the sensor as follows:

    >>> cassandra_sensor = CassandraTableSensor(table="k.t",
    ...                                         cassandra_conn_id="cassandra_default",
    ...                                         task_id="cassandra_sensor")

    :param table: Target Cassandra table.
                  Use dot notation to target a specific keyspace.
    :type table: string
    :param cassandra_conn_id: The connection ID to use
                              when connecting to Cassandra cluster
    :type cassandra_conn_id: string
    """
    template_fields = ('table',)

    @apply_defaults
    def __init__(self, table, cassandra_conn_id, *args, **kwargs):
        """
        Create a new CassandraTableSensor.

        Remaining positional and keyword arguments are passed through
        to the parent operator's constructor.
        """
        super(CassandraTableSensor, self).__init__(*args, **kwargs)
        self.table = table
        self.cassandra_conn_id = cassandra_conn_id

    def poke(self, context):
        """
        Log the check, then return the hook's ``table_exists`` result.

        :param context: Task context; not used by this sensor.
        """
        self.log.info('Sensor check existence of table: %s', self.table)
        return CassandraHook(self.cassandra_conn_id).table_exists(self.table)

Просмотреть файл

@ -198,7 +198,8 @@ Sensors
.. autoclass:: airflow.contrib.sensors.aws_redshift_cluster_sensor.AwsRedshiftClusterSensor .. autoclass:: airflow.contrib.sensors.aws_redshift_cluster_sensor.AwsRedshiftClusterSensor
.. autoclass:: airflow.contrib.sensors.bash_sensor.BashSensor .. autoclass:: airflow.contrib.sensors.bash_sensor.BashSensor
.. autoclass:: airflow.contrib.sensors.bigquery_sensor.BigQueryTableSensor .. autoclass:: airflow.contrib.sensors.bigquery_sensor.BigQueryTableSensor
.. autoclass:: airflow.contrib.sensors.cassandra_sensor.CassandraRecordSensor .. autoclass:: airflow.contrib.sensors.cassandra_record_sensor.CassandraRecordSensor
.. autoclass:: airflow.contrib.sensors.cassandra_table_sensor.CassandraTableSensor
.. autoclass:: airflow.contrib.sensors.datadog_sensor.DatadogSensor .. autoclass:: airflow.contrib.sensors.datadog_sensor.DatadogSensor
.. autoclass:: airflow.contrib.sensors.emr_base_sensor.EmrBaseSensor .. autoclass:: airflow.contrib.sensors.emr_base_sensor.EmrBaseSensor
.. autoclass:: airflow.contrib.sensors.emr_job_flow_sensor.EmrJobFlowSensor .. autoclass:: airflow.contrib.sensors.emr_job_flow_sensor.EmrJobFlowSensor

Просмотреть файл

@ -39,6 +39,25 @@ class CassandraHookTest(unittest.TestCase):
conn_id='cassandra_test', conn_type='cassandra', conn_id='cassandra_test', conn_type='cassandra',
host='host-1,host-2', port='9042', schema='test_keyspace', host='host-1,host-2', port='9042', schema='test_keyspace',
extra='{"load_balancing_policy":"TokenAwarePolicy"}')) extra='{"load_balancing_policy":"TokenAwarePolicy"}'))
db.merge_conn(
models.Connection(
conn_id='cassandra_default_with_schema', conn_type='cassandra',
host='localhost', port='9042', schema='s'))
hook = CassandraHook("cassandra_default")
session = hook.get_conn()
cqls = [
"DROP SCHEMA IF EXISTS s",
"""
CREATE SCHEMA s WITH REPLICATION =
{ 'class' : 'SimpleStrategy', 'replication_factor' : 1 }
""",
]
for cql in cqls:
session.execute(cql)
session.shutdown()
hook.shutdown_cluster()
def test_get_conn(self): def test_get_conn(self):
with mock.patch.object(Cluster, "connect") as mock_connect, \ with mock.patch.object(Cluster, "connect") as mock_connect, \
@ -117,16 +136,10 @@ class CassandraHookTest(unittest.TestCase):
thrown = True thrown = True
self.assertEqual(should_throw, thrown) self.assertEqual(should_throw, thrown)
def test_record_exists(self): def test_record_exists_with_keyspace_from_cql(self):
hook = CassandraHook() hook = CassandraHook("cassandra_default")
session = hook.get_conn() session = hook.get_conn()
cqls = [ cqls = [
"DROP SCHEMA IF EXISTS s",
"""
CREATE SCHEMA s WITH REPLICATION =
{ 'class' : 'SimpleStrategy', 'replication_factor' : 1 }
""",
"DROP TABLE IF EXISTS s.t", "DROP TABLE IF EXISTS s.t",
"CREATE TABLE s.t (pk1 text, pk2 text, c text, PRIMARY KEY (pk1, pk2))", "CREATE TABLE s.t (pk1 text, pk2 text, c text, PRIMARY KEY (pk1, pk2))",
"INSERT INTO s.t (pk1, pk2, c) VALUES ('foo', 'bar', 'baz')", "INSERT INTO s.t (pk1, pk2, c) VALUES ('foo', 'bar', 'baz')",
@ -137,6 +150,58 @@ class CassandraHookTest(unittest.TestCase):
self.assertTrue(hook.record_exists("s.t", {"pk1": "foo", "pk2": "bar"})) self.assertTrue(hook.record_exists("s.t", {"pk1": "foo", "pk2": "bar"}))
self.assertFalse(hook.record_exists("s.t", {"pk1": "foo", "pk2": "baz"})) self.assertFalse(hook.record_exists("s.t", {"pk1": "foo", "pk2": "baz"}))
session.shutdown()
hook.shutdown_cluster()
def test_record_exists_with_keyspace_from_session(self):
    """
    record_exists() must find rows when the table name has no keyspace prefix.

    Uses the 'cassandra_default_with_schema' connection, so the keyspace is
    presumably taken from that connection's schema rather than from the
    table argument.
    """
    hook = CassandraHook("cassandra_default_with_schema")
    session = hook.get_conn()
    try:
        cqls = [
            "DROP TABLE IF EXISTS t",
            "CREATE TABLE t (pk1 text, pk2 text, c text, PRIMARY KEY (pk1, pk2))",
            "INSERT INTO t (pk1, pk2, c) VALUES ('foo', 'bar', 'baz')",
        ]
        for cql in cqls:
            session.execute(cql)

        self.assertTrue(hook.record_exists("t", {"pk1": "foo", "pk2": "bar"}))
        self.assertFalse(hook.record_exists("t", {"pk1": "foo", "pk2": "baz"}))
    finally:
        # Release the session and cluster even when a statement or an
        # assertion fails, so a failing test does not leak connections.
        session.shutdown()
        hook.shutdown_cluster()
def test_table_exists_with_keyspace_from_cql(self):
    """
    table_exists() must resolve a keyspace-qualified name ('s.t').

    Creates table 's.t', then checks that it is reported as existing
    while the never-created 's.u' is not.
    """
    hook = CassandraHook("cassandra_default")
    session = hook.get_conn()
    try:
        cqls = [
            "DROP TABLE IF EXISTS s.t",
            "CREATE TABLE s.t (pk1 text PRIMARY KEY)",
        ]
        for cql in cqls:
            session.execute(cql)

        self.assertTrue(hook.table_exists("s.t"))
        self.assertFalse(hook.table_exists("s.u"))
    finally:
        # Release the session and cluster even when a statement or an
        # assertion fails, so a failing test does not leak connections.
        session.shutdown()
        hook.shutdown_cluster()
def test_table_exists_with_keyspace_from_session(self):
    """
    table_exists() must work when the table name has no keyspace prefix.

    Uses the 'cassandra_default_with_schema' connection, so the keyspace is
    presumably taken from that connection's schema rather than from the
    table argument.
    """
    hook = CassandraHook("cassandra_default_with_schema")
    session = hook.get_conn()
    try:
        cqls = [
            "DROP TABLE IF EXISTS t",
            "CREATE TABLE t (pk1 text PRIMARY KEY)",
        ]
        for cql in cqls:
            session.execute(cql)

        self.assertTrue(hook.table_exists("t"))
        self.assertFalse(hook.table_exists("u"))
    finally:
        # Release the session and cluster even when a statement or an
        # assertion fails, so a failing test does not leak connections.
        session.shutdown()
        hook.shutdown_cluster()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

Просмотреть файл

@ -24,7 +24,8 @@ from mock import patch
from airflow import DAG from airflow import DAG
from airflow import configuration from airflow import configuration
from airflow.contrib.sensors.cassandra_sensor import CassandraRecordSensor from airflow.contrib.sensors.cassandra_record_sensor import CassandraRecordSensor
from airflow.contrib.sensors.cassandra_table_sensor import CassandraTableSensor
from airflow.utils import timezone from airflow.utils import timezone
@ -54,5 +55,27 @@ class TestCassandraRecordSensor(unittest.TestCase):
mock_record_exists.assert_called_once_with('t', {'foo': 'bar'}) mock_record_exists.assert_called_once_with('t', {'foo': 'bar'})
class TestCassandraTableSensor(unittest.TestCase):
    """Unit tests for CassandraTableSensor with the hook's table_exists patched."""

    def setUp(self):
        """Load the test configuration and build a sensor attached to a test DAG."""
        configuration.load_test_config()
        self.dag = DAG(
            'test_dag_id',
            default_args={'owner': 'airflow', 'start_date': DEFAULT_DATE},
        )
        self.sensor = CassandraTableSensor(
            table='t',
            cassandra_conn_id='cassandra_default',
            task_id='test_task',
            dag=self.dag,
        )

    @patch("airflow.contrib.hooks.cassandra_hook.CassandraHook.table_exists")
    def test_poke(self, mock_table_exists):
        """poke() must call the hook's table_exists exactly once with the table name."""
        self.sensor.poke(None)
        mock_table_exists.assert_called_once_with('t')
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()