[AIRFLOW-3212] Add AwsGlueCatalogPartitionSensor (#4112)

Adds AwsGlueCatalogPartitionSensor and AwsGlueCatalogHook with
supporting functions. Unit tests are included but rely on mocking, since
Moto does not yet fully support the AWS Glue Catalog.
This commit is contained in:
Mike Mole 2019-01-11 14:35:08 -05:00 коммит произвёл Ash Berlin-Taylor
Родитель 2c8c7d93d1
Коммит 71dd6017e7
6 изменённых файлов: 442 добавлений и 1 удалений

Просмотреть файл

@ -0,0 +1,118 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from airflow.contrib.hooks.aws_hook import AwsHook
class AwsGlueCatalogHook(AwsHook):
    """
    Interact with AWS Glue Catalog

    :param aws_conn_id: ID of the Airflow connection where
        credentials and extra configuration are stored
    :type aws_conn_id: str
    :param region_name: aws region name (example: us-east-1)
    :type region_name: str
    """

    def __init__(self,
                 aws_conn_id='aws_default',
                 region_name=None,
                 *args,
                 **kwargs):
        self.region_name = region_name
        # Forward aws_conn_id positionally, ahead of *args. Passing it as a
        # keyword next to *args made any extra positional argument collide
        # with it ("got multiple values for argument 'aws_conn_id'").
        # NOTE(review): assumes aws_conn_id is AwsHook's first positional
        # parameter -- confirm against AwsHook.__init__.
        super(AwsGlueCatalogHook, self).__init__(aws_conn_id, *args, **kwargs)

    def get_conn(self):
        """
        Returns glue connection object.

        The client is requested through ``get_client_type('glue', ...)`` with
        the hook's ``region_name`` and is also stored on ``self.conn``.
        """
        self.conn = self.get_client_type('glue', self.region_name)
        return self.conn

    def get_partitions(self,
                       database_name,
                       table_name,
                       expression='',
                       page_size=None,
                       max_items=None):
        """
        Retrieves the partition values for a table.

        :param database_name: The name of the catalog database where the partitions reside.
        :type database_name: str
        :param table_name: The name of the partitions' table.
        :type table_name: str
        :param expression: An expression filtering the partitions to be returned.
            Please see official AWS documentation for further information.
            https://docs.aws.amazon.com/glue/latest/dg/aws-glue-api-catalog-partitions.html#aws-glue-api-catalog-partitions-GetPartitions
        :type expression: str
        :param page_size: pagination size
        :type page_size: int
        :param max_items: maximum items to return
        :type max_items: int
        :return: set of partition values where each value is a tuple since
            a partition may be composed of multiple columns. For example:
            {('2018-01-01','1'), ('2018-01-01','2')}
        :rtype: set
        """
        config = {
            'PageSize': page_size,
            'MaxItems': max_items,
        }

        paginator = self.get_conn().get_paginator('get_partitions')
        pages = paginator.paginate(
            DatabaseName=database_name,
            TableName=table_name,
            Expression=expression,
            PaginationConfig=config
        )

        partitions = set()
        for page in pages:
            # Each partition's values become a tuple so they are hashable and
            # multi-column partitions stay grouped together.
            partitions.update(tuple(p['Values']) for p in page['Partitions'])

        return partitions

    def check_for_partition(self, database_name, table_name, expression):
        """
        Checks whether a partition exists

        :param database_name: Name of hive database (schema) @table belongs to
        :type database_name: str
        :param table_name: Name of hive table @partition belongs to
        :type table_name: str
        :param expression: Expression that matches the partitions to check for
            (eg `a = 'b' AND c = 'd'`)
        :type expression: str
        :return: True if at least one partition matches the expression
        :rtype: bool

        >>> hook = AwsGlueCatalogHook()
        >>> t = 'static_babynames_partitioned'
        >>> hook.check_for_partition('airflow', t, "ds='2015-01-01'")
        True
        """
        # One matching partition is enough to answer the question, so cap the
        # request at a single item.
        partitions = self.get_partitions(database_name, table_name, expression, max_items=1)
        return bool(partitions)

Просмотреть файл

@ -0,0 +1,93 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from airflow.sensors.base_sensor_operator import BaseSensorOperator
from airflow.utils.decorators import apply_defaults
class AwsGlueCatalogPartitionSensor(BaseSensorOperator):
    """
    Waits for a partition to show up in AWS Glue Catalog.

    :param table_name: The name of the table to wait for, supports the dot
        notation (my_database.my_table)
    :type table_name: str
    :param expression: The partition clause to wait for. This is passed as
        is to the AWS Glue Catalog API's get_partitions function,
        and supports SQL like notation as in ``ds='2015-01-01'
        AND type='value'`` and comparison operators as in ``"ds>=2015-01-01"``.
        See https://docs.aws.amazon.com/glue/latest/dg/aws-glue-api-catalog-partitions.html#aws-glue-api-catalog-partitions-GetPartitions
    :type expression: str
    :param aws_conn_id: ID of the Airflow connection where
        credentials and extra configuration are stored
    :type aws_conn_id: str
    :param region_name: Optional aws region name (example: us-east-1). Uses region from connection
        if not specified.
    :type region_name: str
    :param database_name: The name of the catalog database where the partitions reside.
    :type database_name: str
    :param poke_interval: Time in seconds that the job should wait in
        between each tries
    :type poke_interval: int
    """
    template_fields = ('database_name', 'table_name', 'expression',)
    ui_color = '#C5CAE9'

    @apply_defaults
    def __init__(self,
                 table_name, expression="ds='{{ ds }}'",
                 aws_conn_id='aws_default',
                 region_name=None,
                 database_name='default',
                 poke_interval=60 * 3,
                 *args,
                 **kwargs):
        super(AwsGlueCatalogPartitionSensor, self).__init__(
            poke_interval=poke_interval, *args, **kwargs)
        self.aws_conn_id = aws_conn_id
        self.region_name = region_name
        self.table_name = table_name
        self.expression = expression
        self.database_name = database_name

    def poke(self, context):
        """
        Checks for existence of the partition in the AWS Glue Catalog table

        :param context: task context supplied by the sensor machinery; not
            used by this method
        :return: the result of the hook's ``check_for_partition`` call
        """
        if '.' in self.table_name:
            # Split on the first dot only: everything before it is the
            # database, the remainder is the table. An unbounded split would
            # raise ValueError on names containing more than one dot.
            self.database_name, self.table_name = self.table_name.split('.', 1)
        # Lazy %-style arguments let the logger skip formatting when INFO is
        # disabled; the emitted message text is unchanged.
        self.log.info(
            'Poking for table %s.%s, expression %s',
            self.database_name, self.table_name, self.expression)

        return self.get_hook().check_for_partition(
            self.database_name, self.table_name, self.expression)

    def get_hook(self):
        """
        Gets the AwsGlueCatalogHook

        The hook is created on first use and then kept on ``self.hook`` so
        later pokes reuse it.
        """
        if not hasattr(self, 'hook'):
            # Imported here rather than at module level, as in the original.
            from airflow.contrib.hooks.aws_glue_catalog_hook import AwsGlueCatalogHook
            self.hook = AwsGlueCatalogHook(
                aws_conn_id=self.aws_conn_id,
                region_name=self.region_name)

        return self.hook

Просмотреть файл

@ -236,6 +236,7 @@ Sensors
^^^^^^^
.. autoclass:: airflow.contrib.sensors.aws_athena_sensor.AthenaSensor
.. autoclass:: airflow.contrib.sensors.aws_glue_catalog_partition_sensor.AwsGlueCatalogPartitionSensor
.. autoclass:: airflow.contrib.sensors.aws_redshift_cluster_sensor.AwsRedshiftClusterSensor
.. autoclass:: airflow.contrib.sensors.azure_cosmos_sensor.AzureCosmosDocumentSensor
.. autoclass:: airflow.contrib.sensors.bash_sensor.BashSensor
@ -420,6 +421,7 @@ Community contributed hooks
.. autoclass:: airflow.contrib.hooks.aws_athena_hook.AWSAthenaHook
.. autoclass:: airflow.contrib.hooks.aws_dynamodb_hook.AwsDynamoDBHook
.. autoclass:: airflow.contrib.hooks.aws_firehose_hook.AwsFirehoseHook
.. autoclass:: airflow.contrib.hooks.aws_glue_catalog_hook.AwsGlueCatalogHook
.. autoclass:: airflow.contrib.hooks.aws_hook.AwsHook
.. autoclass:: airflow.contrib.hooks.aws_lambda_hook.AwsLambdaHook
.. autoclass:: airflow.contrib.hooks.aws_sns_hook.AwsSnsHook

Просмотреть файл

@ -251,7 +251,7 @@ devel = [
'lxml>=4.0.0',
'mock',
'mongomock',
'moto==1.1.19',
'moto==1.3.5',
'nose',
'nose-ignore-docstring==0.2',
'nose-timer',

Просмотреть файл

@ -0,0 +1,110 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import unittest
from airflow.contrib.hooks.aws_glue_catalog_hook import AwsGlueCatalogHook
try:
from moto import mock_glue
except ImportError:
mock_glue = None
try:
from unittest import mock
except ImportError:
import mock
@unittest.skipIf(mock_glue is None,
                 "Skipping test because moto.mock_glue is not available")
class TestAwsGlueCatalogHook(unittest.TestCase):
    """Unit tests for AwsGlueCatalogHook; Glue API calls are replaced by mocks."""

    @mock_glue
    def test_get_conn_returns_a_boto3_connection(self):
        hook = AwsGlueCatalogHook(region_name="us-east-1")

        self.assertIsNotNone(hook.get_conn())

    @mock_glue
    def test_conn_id(self):
        hook = AwsGlueCatalogHook(aws_conn_id='my_aws_conn_id', region_name="us-east-1")

        self.assertEqual(hook.aws_conn_id, 'my_aws_conn_id')

    @mock_glue
    def test_region(self):
        hook = AwsGlueCatalogHook(region_name="us-west-2")

        self.assertEqual(hook.region_name, 'us-west-2')

    @mock_glue
    @mock.patch.object(AwsGlueCatalogHook, 'get_conn')
    def test_get_partitions_empty(self, mock_get_conn):
        # The hook calls get_conn().get_paginator(...).paginate(...), so the
        # canned response must hang off the return_value chain. Configuring
        # mock_get_conn.get_paginator.paginate directly is never reached, and
        # such a test only passes because MagicMock iterates as empty.
        mock_paginator = mock_get_conn.return_value.get_paginator.return_value
        mock_paginator.paginate.return_value = [{'Partitions': []}]

        hook = AwsGlueCatalogHook(region_name="us-east-1")

        self.assertEqual(hook.get_partitions('db', 'tbl'), set())

    @mock_glue
    @mock.patch.object(AwsGlueCatalogHook, 'get_conn')
    def test_get_partitions(self, mock_get_conn):
        response = [{
            'Partitions': [{
                'Values': ['2015-01-01']
            }]
        }]
        mock_paginator = mock.Mock()
        mock_paginator.paginate.return_value = response
        mock_conn = mock.Mock()
        mock_conn.get_paginator.return_value = mock_paginator
        mock_get_conn.return_value = mock_conn

        hook = AwsGlueCatalogHook(region_name="us-east-1")
        result = hook.get_partitions('db',
                                     'tbl',
                                     expression='foo=bar',
                                     page_size=2,
                                     max_items=3)

        self.assertEqual(result, {('2015-01-01',)})
        mock_conn.get_paginator.assert_called_once_with('get_partitions')
        mock_paginator.paginate.assert_called_once_with(DatabaseName='db',
                                                        TableName='tbl',
                                                        Expression='foo=bar',
                                                        PaginationConfig={
                                                            'PageSize': 2,
                                                            'MaxItems': 3})

    @mock_glue
    @mock.patch.object(AwsGlueCatalogHook, 'get_partitions')
    def test_check_for_partition(self, mock_get_partitions):
        mock_get_partitions.return_value = {('2018-01-01',)}

        hook = AwsGlueCatalogHook(region_name="us-east-1")

        self.assertTrue(hook.check_for_partition('db', 'tbl', 'expr'))
        mock_get_partitions.assert_called_once_with('db', 'tbl', 'expr', max_items=1)

    @mock_glue
    @mock.patch.object(AwsGlueCatalogHook, 'get_partitions')
    def test_check_for_partition_false(self, mock_get_partitions):
        mock_get_partitions.return_value = set()

        hook = AwsGlueCatalogHook(region_name="us-east-1")

        self.assertFalse(hook.check_for_partition('db', 'tbl', 'expr'))
# Allow running this test module directly as a script.
if __name__ == '__main__':
    unittest.main()

Просмотреть файл

@ -0,0 +1,118 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import unittest
from airflow import configuration
from airflow.contrib.hooks.aws_glue_catalog_hook import AwsGlueCatalogHook
from airflow.contrib.sensors.aws_glue_catalog_partition_sensor import AwsGlueCatalogPartitionSensor
try:
from moto import mock_glue
except ImportError:
mock_glue = None
try:
from unittest import mock
except ImportError:
import mock
@unittest.skipIf(mock_glue is None,
                 "Skipping test because moto.mock_glue is not available")
class TestAwsGlueCatalogPartitionSensor(unittest.TestCase):
    """Tests for AwsGlueCatalogPartitionSensor with the hook's partition check mocked."""

    task_id = 'test_glue_catalog_partition_sensor'

    def setUp(self):
        configuration.load_test_config()

    @mock_glue
    @mock.patch.object(AwsGlueCatalogHook, 'check_for_partition')
    def test_poke(self, mock_check_for_partition):
        """poke() is truthy when the hook reports the partition exists."""
        mock_check_for_partition.return_value = True
        sensor = AwsGlueCatalogPartitionSensor(task_id=self.task_id,
                                               table_name='tbl')
        poked = sensor.poke(None)
        self.assertTrue(poked)

    @mock_glue
    @mock.patch.object(AwsGlueCatalogHook, 'check_for_partition')
    def test_poke_false(self, mock_check_for_partition):
        """poke() is falsy when the hook reports no matching partition."""
        mock_check_for_partition.return_value = False
        sensor = AwsGlueCatalogPartitionSensor(task_id=self.task_id,
                                               table_name='tbl')
        poked = sensor.poke(None)
        self.assertFalse(poked)

    @mock_glue
    @mock.patch.object(AwsGlueCatalogHook, 'check_for_partition')
    def test_poke_default_args(self, mock_check_for_partition):
        """Defaults for connection, region, database and expression reach the hook."""
        tbl = 'test_glue_catalog_partition_sensor_tbl'
        sensor = AwsGlueCatalogPartitionSensor(task_id=self.task_id,
                                               table_name=tbl)
        sensor.poke(None)

        created_hook = sensor.hook
        self.assertEqual(created_hook.region_name, None)
        self.assertEqual(created_hook.aws_conn_id, 'aws_default')
        mock_check_for_partition.assert_called_once_with(
            'default', tbl, "ds='{{ ds }}'")

    @mock_glue
    @mock.patch.object(AwsGlueCatalogHook, 'check_for_partition')
    def test_poke_nondefault_args(self, mock_check_for_partition):
        """Explicit constructor arguments are kept on the sensor and passed to the hook."""
        sensor_kwargs = {
            'table_name': 'my_table',
            'expression': 'col=val',
            'aws_conn_id': 'my_aws_conn_id',
            'region_name': 'us-west-2',
            'database_name': 'my_db',
            'poke_interval': 2,
            'timeout': 3,
        }
        sensor = AwsGlueCatalogPartitionSensor(task_id=self.task_id,
                                               **sensor_kwargs)
        sensor.poke(None)

        created_hook = sensor.hook
        self.assertEqual(created_hook.region_name, sensor_kwargs['region_name'])
        self.assertEqual(created_hook.aws_conn_id, sensor_kwargs['aws_conn_id'])
        self.assertEqual(sensor.poke_interval, sensor_kwargs['poke_interval'])
        self.assertEqual(sensor.timeout, sensor_kwargs['timeout'])
        mock_check_for_partition.assert_called_once_with(
            sensor_kwargs['database_name'],
            sensor_kwargs['table_name'],
            sensor_kwargs['expression'])

    @mock_glue
    @mock.patch.object(AwsGlueCatalogHook, 'check_for_partition')
    def test_dot_notation(self, mock_check_for_partition):
        """A 'database.table' name is split into database and table before the check."""
        qualified_name = 'my_db.my_tbl'
        sensor = AwsGlueCatalogPartitionSensor(task_id=self.task_id,
                                               table_name=qualified_name)
        sensor.poke(None)

        mock_check_for_partition.assert_called_once_with(
            'my_db', 'my_tbl', "ds='{{ ds }}'")
# Allow running this test module directly as a script.
if __name__ == '__main__':
    unittest.main()