From 5f49ebf0182c19ffde6e125dcf426afa4db02786 Mon Sep 17 00:00:00 2001 From: Kengo Seki Date: Wed, 20 Jun 2018 20:36:32 +0200 Subject: [PATCH] [AIRFLOW-2640] Add Cassandra table sensor Just like a partition sensor for Hive, this PR adds a sensor that waits for a table to be created in Cassandra cluster. Closes #3518 from sekikn/AIRFLOW-2640 --- airflow/contrib/hooks/cassandra_hook.py | 19 ++++- ...a_sensor.py => cassandra_record_sensor.py} | 7 +- .../contrib/sensors/cassandra_table_sensor.py | 56 +++++++++++++ docs/code.rst | 3 +- tests/contrib/hooks/test_cassandra_hook.py | 81 +++++++++++++++++-- .../contrib/sensors/test_cassandra_sensor.py | 25 +++++- 6 files changed, 176 insertions(+), 15 deletions(-) rename airflow/contrib/sensors/{cassandra_sensor.py => cassandra_record_sensor.py} (88%) create mode 100644 airflow/contrib/sensors/cassandra_table_sensor.py diff --git a/airflow/contrib/hooks/cassandra_hook.py b/airflow/contrib/hooks/cassandra_hook.py index 704ba0d8d0..0e0b47708d 100644 --- a/airflow/contrib/hooks/cassandra_hook.py +++ b/airflow/contrib/hooks/cassandra_hook.py @@ -158,6 +158,21 @@ class CassandraHook(BaseHook, LoggingMixin): child_policy_args) return TokenAwarePolicy(child_policy) + def table_exists(self, table): + """ + Checks if a table exists in Cassandra + + :param table: Target Cassandra table. + Use dot notation to target a specific keyspace. + :type table: string + """ + keyspace = self.keyspace + if '.' in table: + keyspace, table = table.split('.', 1) + cluster_metadata = self.get_conn().cluster.metadata + return (keyspace in cluster_metadata.keyspaces and + table in cluster_metadata.keyspaces[keyspace].tables) + def record_exists(self, table, keys): """ Checks if a record exists in Cassandra @@ -168,12 +183,12 @@ class CassandraHook(BaseHook, LoggingMixin): :param keys: The keys and their values to check the existence. :type keys: dict """ - keyspace = None + keyspace = self.keyspace if '.' in table: keyspace, table = table.split('.', 1) ks = " AND ".join("{}=%({})s".format(key, key) for key in keys.keys()) cql = "SELECT * FROM {keyspace}.{table} WHERE {keys}".format( - keyspace=(keyspace or self.keyspace), table=table, keys=ks) + keyspace=keyspace, table=table, keys=ks) try: rs = self.get_conn().execute(cql, keys) diff --git a/airflow/contrib/sensors/cassandra_sensor.py b/airflow/contrib/sensors/cassandra_record_sensor.py similarity index 88% rename from airflow/contrib/sensors/cassandra_sensor.py rename to airflow/contrib/sensors/cassandra_record_sensor.py index aef66122e9..493a6ba6b1 100644 --- a/airflow/contrib/sensors/cassandra_sensor.py +++ b/airflow/contrib/sensors/cassandra_record_sensor.py @@ -29,9 +29,10 @@ class CassandraRecordSensor(BaseSensorOperator): primary keys 'p1' and 'p2' to be populated in keyspace 'k' and table 't', instantiate it as follows: - >>> CassandraRecordSensor(table="k.t", keys={"p1": "v1", "p2": "v2"}, - ... cassandra_conn_id="cassandra_default", task_id="cassandra_sensor") - + >>> cassandra_sensor = CassandraRecordSensor(table="k.t", + ... keys={"p1": "v1", "p2": "v2"}, + ... cassandra_conn_id="cassandra_default", + ... task_id="cassandra_sensor") """ template_fields = ('table', 'keys') diff --git a/airflow/contrib/sensors/cassandra_table_sensor.py b/airflow/contrib/sensors/cassandra_table_sensor.py new file mode 100644 index 0000000000..5a85995aca --- /dev/null +++ b/airflow/contrib/sensors/cassandra_table_sensor.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from airflow.contrib.hooks.cassandra_hook import CassandraHook +from airflow.sensors.base_sensor_operator import BaseSensorOperator +from airflow.utils.decorators import apply_defaults + + +class CassandraTableSensor(BaseSensorOperator): + """ + Checks for the existence of a table in a Cassandra cluster. + + For example, if you want to wait for a table called 't' to be created + in a keyspace 'k', instantiate it as follows: + + >>> cassandra_sensor = CassandraTableSensor(table="k.t", + ... cassandra_conn_id="cassandra_default", + ... task_id="cassandra_sensor") + """ + template_fields = ('table',) + + @apply_defaults + def __init__(self, table, cassandra_conn_id, *args, **kwargs): + """ + Create a new CassandraTableSensor + + :param table: Target Cassandra table. + Use dot notation to target a specific keyspace. + :type table: string + :param cassandra_conn_id: The connection ID to use + when connecting to Cassandra cluster + :type cassandra_conn_id: string + """ + super(CassandraTableSensor, self).__init__(*args, **kwargs) + self.cassandra_conn_id = cassandra_conn_id + self.table = table + + def poke(self, context): + self.log.info('Sensor check existence of table: %s', self.table) + hook = CassandraHook(self.cassandra_conn_id) + return hook.table_exists(self.table) diff --git a/docs/code.rst b/docs/code.rst index f055fc60cf..a64c2779d2 100644 --- a/docs/code.rst +++ b/docs/code.rst @@ -198,7 +198,8 @@ Sensors .. autoclass:: airflow.contrib.sensors.aws_redshift_cluster_sensor.AwsRedshiftClusterSensor .. autoclass:: airflow.contrib.sensors.bash_sensor.BashSensor .. autoclass:: airflow.contrib.sensors.bigquery_sensor.BigQueryTableSensor -.. autoclass:: airflow.contrib.sensors.cassandra_sensor.CassandraRecordSensor +.. autoclass:: airflow.contrib.sensors.cassandra_record_sensor.CassandraRecordSensor +.. autoclass:: airflow.contrib.sensors.cassandra_table_sensor.CassandraTableSensor .. autoclass:: airflow.contrib.sensors.datadog_sensor.DatadogSensor .. autoclass:: airflow.contrib.sensors.emr_base_sensor.EmrBaseSensor .. autoclass:: airflow.contrib.sensors.emr_job_flow_sensor.EmrJobFlowSensor diff --git a/tests/contrib/hooks/test_cassandra_hook.py b/tests/contrib/hooks/test_cassandra_hook.py index e420ec0095..9cb0739993 100644 --- a/tests/contrib/hooks/test_cassandra_hook.py +++ b/tests/contrib/hooks/test_cassandra_hook.py @@ -39,6 +39,25 @@ class CassandraHookTest(unittest.TestCase): conn_id='cassandra_test', conn_type='cassandra', host='host-1,host-2', port='9042', schema='test_keyspace', extra='{"load_balancing_policy":"TokenAwarePolicy"}')) + db.merge_conn( + models.Connection( + conn_id='cassandra_default_with_schema', conn_type='cassandra', + host='localhost', port='9042', schema='s')) + + hook = CassandraHook("cassandra_default") + session = hook.get_conn() + cqls = [ + "DROP SCHEMA IF EXISTS s", + """ + CREATE SCHEMA s WITH REPLICATION = + { 'class' : 'SimpleStrategy', 'replication_factor' : 1 } + """, + ] + for cql in cqls: + session.execute(cql) + + session.shutdown() + hook.shutdown_cluster() def test_get_conn(self): with mock.patch.object(Cluster, "connect") as mock_connect, \ @@ -117,16 +136,10 @@ class CassandraHookTest(unittest.TestCase): thrown = True self.assertEqual(should_throw, thrown) - def test_record_exists(self): - hook = CassandraHook() + def test_record_exists_with_keyspace_from_cql(self): + hook = CassandraHook("cassandra_default") session = hook.get_conn() - cqls = [ - "DROP SCHEMA IF EXISTS s", - """ - CREATE SCHEMA s WITH REPLICATION = - { 'class' : 'SimpleStrategy', 'replication_factor' : 1 } - """, "DROP TABLE IF EXISTS s.t", "CREATE TABLE s.t (pk1 text, pk2 text, c text, PRIMARY KEY (pk1, pk2))", "INSERT INTO s.t (pk1, pk2, c) VALUES ('foo', 'bar', 'baz')", @@ -137,6 +150,58 @@ class CassandraHookTest(unittest.TestCase): self.assertTrue(hook.record_exists("s.t", {"pk1": "foo", "pk2": "bar"})) self.assertFalse(hook.record_exists("s.t", {"pk1": "foo", "pk2": "baz"})) + session.shutdown() + hook.shutdown_cluster() + + def test_record_exists_with_keyspace_from_session(self): + hook = CassandraHook("cassandra_default_with_schema") + session = hook.get_conn() + cqls = [ + "DROP TABLE IF EXISTS t", + "CREATE TABLE t (pk1 text, pk2 text, c text, PRIMARY KEY (pk1, pk2))", + "INSERT INTO t (pk1, pk2, c) VALUES ('foo', 'bar', 'baz')", + ] + for cql in cqls: + session.execute(cql) + + self.assertTrue(hook.record_exists("t", {"pk1": "foo", "pk2": "bar"})) + self.assertFalse(hook.record_exists("t", {"pk1": "foo", "pk2": "baz"})) + + session.shutdown() + hook.shutdown_cluster() + + def test_table_exists_with_keyspace_from_cql(self): + hook = CassandraHook("cassandra_default") + session = hook.get_conn() + cqls = [ + "DROP TABLE IF EXISTS s.t", + "CREATE TABLE s.t (pk1 text PRIMARY KEY)", + ] + for cql in cqls: + session.execute(cql) + + self.assertTrue(hook.table_exists("s.t")) + self.assertFalse(hook.table_exists("s.u")) + + session.shutdown() + hook.shutdown_cluster() + + def test_table_exists_with_keyspace_from_session(self): + hook = CassandraHook("cassandra_default_with_schema") + session = hook.get_conn() + cqls = [ + "DROP TABLE IF EXISTS t", + "CREATE TABLE t (pk1 text PRIMARY KEY)", + ] + for cql in cqls: + session.execute(cql) + + self.assertTrue(hook.table_exists("t")) + self.assertFalse(hook.table_exists("u")) + + session.shutdown() + hook.shutdown_cluster() + if __name__ == '__main__': unittest.main() diff --git a/tests/contrib/sensors/test_cassandra_sensor.py b/tests/contrib/sensors/test_cassandra_sensor.py index 0f0e7f5eb3..c07bc0be2a 100644 --- a/tests/contrib/sensors/test_cassandra_sensor.py +++ b/tests/contrib/sensors/test_cassandra_sensor.py @@ -24,7 +24,8 @@ from mock import patch from airflow import DAG from airflow import configuration -from airflow.contrib.sensors.cassandra_sensor import CassandraRecordSensor +from airflow.contrib.sensors.cassandra_record_sensor import CassandraRecordSensor +from airflow.contrib.sensors.cassandra_table_sensor import CassandraTableSensor from airflow.utils import timezone @@ -54,5 +55,27 @@ class TestCassandraRecordSensor(unittest.TestCase): mock_record_exists.assert_called_once_with('t', {'foo': 'bar'}) +class TestCassandraTableSensor(unittest.TestCase): + + def setUp(self): + configuration.load_test_config() + args = { + 'owner': 'airflow', + 'start_date': DEFAULT_DATE + } + self.dag = DAG('test_dag_id', default_args=args) + self.sensor = CassandraTableSensor( + task_id='test_task', + cassandra_conn_id='cassandra_default', + dag=self.dag, + table='t', + ) + + @patch("airflow.contrib.hooks.cassandra_hook.CassandraHook.table_exists") + def test_poke(self, mock_table_exists): + self.sensor.poke(None) + mock_table_exists.assert_called_once_with('t') + + if __name__ == '__main__': unittest.main()