[AIRFLOW-2640] Add Cassandra table sensor
Just like the partition sensor for Hive, this PR adds a sensor that waits for a table to be created in a Cassandra cluster. Closes #3518 from sekikn/AIRFLOW-2640
This commit is contained in:
Родитель
6ed2fb716a
Коммит
5f49ebf018
|
@ -158,6 +158,21 @@ class CassandraHook(BaseHook, LoggingMixin):
|
|||
child_policy_args)
|
||||
return TokenAwarePolicy(child_policy)
|
||||
|
||||
def table_exists(self, table):
    """
    Report whether a table is present in the Cassandra cluster metadata.

    :param table: Target Cassandra table.
                  Use dot notation (``keyspace.table``) to target a
                  specific keyspace; otherwise the hook's own keyspace
                  is used.
    :type table: string
    :return: True when both the keyspace and the table are listed in
             the cluster metadata, False otherwise.
    """
    # Only the first dot separates keyspace from table name.
    ks_name, dot, tbl_name = table.partition('.')
    if not dot:
        # Unqualified name: fall back to the keyspace configured on the hook.
        ks_name, tbl_name = self.keyspace, table

    known_keyspaces = self.get_conn().cluster.metadata.keyspaces
    if ks_name not in known_keyspaces:
        return False
    return tbl_name in known_keyspaces[ks_name].tables
|
||||
|
||||
def record_exists(self, table, keys):
|
||||
"""
|
||||
Checks if a record exists in Cassandra
|
||||
|
@ -168,12 +183,12 @@ class CassandraHook(BaseHook, LoggingMixin):
|
|||
:param keys: The keys and their values to check the existence.
|
||||
:type keys: dict
|
||||
"""
|
||||
keyspace = None
|
||||
keyspace = self.keyspace
|
||||
if '.' in table:
|
||||
keyspace, table = table.split('.', 1)
|
||||
ks = " AND ".join("{}=%({})s".format(key, key) for key in keys.keys())
|
||||
cql = "SELECT * FROM {keyspace}.{table} WHERE {keys}".format(
|
||||
keyspace=(keyspace or self.keyspace), table=table, keys=ks)
|
||||
keyspace=keyspace, table=table, keys=ks)
|
||||
|
||||
try:
|
||||
rs = self.get_conn().execute(cql, keys)
|
||||
|
|
|
@ -29,9 +29,10 @@ class CassandraRecordSensor(BaseSensorOperator):
|
|||
primary keys 'p1' and 'p2' to be populated in keyspace 'k' and table 't',
|
||||
instantiate it as follows:
|
||||
|
||||
>>> CassandraRecordSensor(table="k.t", keys={"p1": "v1", "p2": "v2"},
|
||||
... cassandra_conn_id="cassandra_default", task_id="cassandra_sensor")
|
||||
<Task(CassandraRecordSensor): cassandra_sensor>
|
||||
>>> cassandra_sensor = CassandraRecordSensor(table="k.t",
|
||||
... keys={"p1": "v1", "p2": "v2"},
|
||||
... cassandra_conn_id="cassandra_default",
|
||||
... task_id="cassandra_sensor")
|
||||
"""
|
||||
template_fields = ('table', 'keys')
|
||||
|
|
@ -0,0 +1,56 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from airflow.contrib.hooks.cassandra_hook import CassandraHook
|
||||
from airflow.sensors.base_sensor_operator import BaseSensorOperator
|
||||
from airflow.utils.decorators import apply_defaults
|
||||
|
||||
|
||||
class CassandraTableSensor(BaseSensorOperator):
    """
    Sensor that succeeds once a given table exists in a Cassandra cluster.

    To wait for table 't' in keyspace 'k' to be created, instantiate the
    sensor like this:

    >>> cassandra_sensor = CassandraTableSensor(table="k.t",
    ...                                         cassandra_conn_id="cassandra_default",
    ...                                         task_id="cassandra_sensor")
    """
    template_fields = ('table',)

    @apply_defaults
    def __init__(self, table, cassandra_conn_id, *args, **kwargs):
        """
        Build a CassandraTableSensor.

        :param table: Target Cassandra table.
                      Use dot notation to target a specific keyspace.
        :type table: string
        :param cassandra_conn_id: ID of the Airflow connection used
                                  to reach the Cassandra cluster
        :type cassandra_conn_id: string
        """
        super(CassandraTableSensor, self).__init__(*args, **kwargs)
        self.table = table
        self.cassandra_conn_id = cassandra_conn_id

    def poke(self, context):
        """
        Check once whether the target table exists.

        :param context: Task context; not used by this check.
        :return: Whatever ``CassandraHook.table_exists`` returns for the table.
        """
        self.log.info('Sensor check existence of table: %s', self.table)
        # A new hook is built on every poke, from the stored connection ID.
        return CassandraHook(self.cassandra_conn_id).table_exists(self.table)
|
|
@ -198,7 +198,8 @@ Sensors
|
|||
.. autoclass:: airflow.contrib.sensors.aws_redshift_cluster_sensor.AwsRedshiftClusterSensor
|
||||
.. autoclass:: airflow.contrib.sensors.bash_sensor.BashSensor
|
||||
.. autoclass:: airflow.contrib.sensors.bigquery_sensor.BigQueryTableSensor
|
||||
.. autoclass:: airflow.contrib.sensors.cassandra_sensor.CassandraRecordSensor
|
||||
.. autoclass:: airflow.contrib.sensors.cassandra_record_sensor.CassandraRecordSensor
|
||||
.. autoclass:: airflow.contrib.sensors.cassandra_table_sensor.CassandraTableSensor
|
||||
.. autoclass:: airflow.contrib.sensors.datadog_sensor.DatadogSensor
|
||||
.. autoclass:: airflow.contrib.sensors.emr_base_sensor.EmrBaseSensor
|
||||
.. autoclass:: airflow.contrib.sensors.emr_job_flow_sensor.EmrJobFlowSensor
|
||||
|
|
|
@ -39,6 +39,25 @@ class CassandraHookTest(unittest.TestCase):
|
|||
conn_id='cassandra_test', conn_type='cassandra',
|
||||
host='host-1,host-2', port='9042', schema='test_keyspace',
|
||||
extra='{"load_balancing_policy":"TokenAwarePolicy"}'))
|
||||
db.merge_conn(
|
||||
models.Connection(
|
||||
conn_id='cassandra_default_with_schema', conn_type='cassandra',
|
||||
host='localhost', port='9042', schema='s'))
|
||||
|
||||
hook = CassandraHook("cassandra_default")
|
||||
session = hook.get_conn()
|
||||
cqls = [
|
||||
"DROP SCHEMA IF EXISTS s",
|
||||
"""
|
||||
CREATE SCHEMA s WITH REPLICATION =
|
||||
{ 'class' : 'SimpleStrategy', 'replication_factor' : 1 }
|
||||
""",
|
||||
]
|
||||
for cql in cqls:
|
||||
session.execute(cql)
|
||||
|
||||
session.shutdown()
|
||||
hook.shutdown_cluster()
|
||||
|
||||
def test_get_conn(self):
|
||||
with mock.patch.object(Cluster, "connect") as mock_connect, \
|
||||
|
@ -117,16 +136,10 @@ class CassandraHookTest(unittest.TestCase):
|
|||
thrown = True
|
||||
self.assertEqual(should_throw, thrown)
|
||||
|
||||
def test_record_exists(self):
|
||||
hook = CassandraHook()
|
||||
def test_record_exists_with_keyspace_from_cql(self):
|
||||
hook = CassandraHook("cassandra_default")
|
||||
session = hook.get_conn()
|
||||
|
||||
cqls = [
|
||||
"DROP SCHEMA IF EXISTS s",
|
||||
"""
|
||||
CREATE SCHEMA s WITH REPLICATION =
|
||||
{ 'class' : 'SimpleStrategy', 'replication_factor' : 1 }
|
||||
""",
|
||||
"DROP TABLE IF EXISTS s.t",
|
||||
"CREATE TABLE s.t (pk1 text, pk2 text, c text, PRIMARY KEY (pk1, pk2))",
|
||||
"INSERT INTO s.t (pk1, pk2, c) VALUES ('foo', 'bar', 'baz')",
|
||||
|
@ -137,6 +150,58 @@ class CassandraHookTest(unittest.TestCase):
|
|||
self.assertTrue(hook.record_exists("s.t", {"pk1": "foo", "pk2": "bar"}))
|
||||
self.assertFalse(hook.record_exists("s.t", {"pk1": "foo", "pk2": "baz"}))
|
||||
|
||||
session.shutdown()
|
||||
hook.shutdown_cluster()
|
||||
|
||||
def test_record_exists_with_keyspace_from_session(self):
    """
    record_exists() resolves an unqualified table name against the
    keyspace configured on the connection.
    """
    hook = CassandraHook("cassandra_default_with_schema")
    session = hook.get_conn()
    # Release the session and cluster even when a statement or an
    # assertion fails, so a failing test does not leak the connection.
    try:
        setup_statements = (
            "DROP TABLE IF EXISTS t",
            "CREATE TABLE t (pk1 text, pk2 text, c text, PRIMARY KEY (pk1, pk2))",
            "INSERT INTO t (pk1, pk2, c) VALUES ('foo', 'bar', 'baz')",
        )
        for cql in setup_statements:
            session.execute(cql)

        self.assertTrue(hook.record_exists("t", {"pk1": "foo", "pk2": "bar"}))
        self.assertFalse(hook.record_exists("t", {"pk1": "foo", "pk2": "baz"}))
    finally:
        session.shutdown()
        hook.shutdown_cluster()
|
||||
|
||||
def test_table_exists_with_keyspace_from_cql(self):
    """
    table_exists() honours a keyspace given in dot notation
    (``keyspace.table``).
    """
    hook = CassandraHook("cassandra_default")
    session = hook.get_conn()
    # Release the session and cluster even when a statement or an
    # assertion fails, so a failing test does not leak the connection.
    try:
        setup_statements = (
            "DROP TABLE IF EXISTS s.t",
            "CREATE TABLE s.t (pk1 text PRIMARY KEY)",
        )
        for cql in setup_statements:
            session.execute(cql)

        self.assertTrue(hook.table_exists("s.t"))
        self.assertFalse(hook.table_exists("s.u"))
    finally:
        session.shutdown()
        hook.shutdown_cluster()
|
||||
|
||||
def test_table_exists_with_keyspace_from_session(self):
    """
    table_exists() resolves an unqualified table name against the
    keyspace configured on the connection.
    """
    hook = CassandraHook("cassandra_default_with_schema")
    session = hook.get_conn()
    # Release the session and cluster even when a statement or an
    # assertion fails, so a failing test does not leak the connection.
    try:
        setup_statements = (
            "DROP TABLE IF EXISTS t",
            "CREATE TABLE t (pk1 text PRIMARY KEY)",
        )
        for cql in setup_statements:
            session.execute(cql)

        self.assertTrue(hook.table_exists("t"))
        self.assertFalse(hook.table_exists("u"))
    finally:
        session.shutdown()
        hook.shutdown_cluster()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
@ -24,7 +24,8 @@ from mock import patch
|
|||
|
||||
from airflow import DAG
|
||||
from airflow import configuration
|
||||
from airflow.contrib.sensors.cassandra_sensor import CassandraRecordSensor
|
||||
from airflow.contrib.sensors.cassandra_record_sensor import CassandraRecordSensor
|
||||
from airflow.contrib.sensors.cassandra_table_sensor import CassandraTableSensor
|
||||
from airflow.utils import timezone
|
||||
|
||||
|
||||
|
@ -54,5 +55,27 @@ class TestCassandraRecordSensor(unittest.TestCase):
|
|||
mock_record_exists.assert_called_once_with('t', {'foo': 'bar'})
|
||||
|
||||
|
||||
class TestCassandraTableSensor(unittest.TestCase):
    """Unit tests for CassandraTableSensor with the hook lookup mocked out."""

    def setUp(self):
        """Load the test configuration and build a sensor bound to a fresh DAG."""
        configuration.load_test_config()
        self.dag = DAG(
            'test_dag_id',
            default_args={
                'owner': 'airflow',
                'start_date': DEFAULT_DATE,
            },
        )
        self.sensor = CassandraTableSensor(
            table='t',
            cassandra_conn_id='cassandra_default',
            task_id='test_task',
            dag=self.dag,
        )

    @patch("airflow.contrib.hooks.cassandra_hook.CassandraHook.table_exists")
    def test_poke(self, mock_table_exists):
        """poke() must pass the configured table name to the hook exactly once."""
        self.sensor.poke(None)
        mock_table_exists.assert_called_once_with('t')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
Загрузка…
Ссылка в новой задаче