[AIRFLOW-3212] Add AwsGlueCatalogPartitionSensor (#4112)
Adds AwsGlueCatalogPartitionSensor and AwsGlueCatalogHook with supporting functions. Unit tests are included but rely on mocking, since Moto does not yet fully support the AWS Glue Catalog.
Parent: 2c8c7d93d1
Commit: 71dd6017e7
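For context, a minimal DAG sketch using the new sensor (the dag id, dates, database and table names below are illustrative, not part of this commit):

from datetime import datetime

from airflow import DAG
from airflow.contrib.sensors.aws_glue_catalog_partition_sensor import AwsGlueCatalogPartitionSensor

# Hypothetical DAG; only the sensor class and its parameters come from this commit.
dag = DAG('example_glue_partition',
          start_date=datetime(2018, 1, 1),
          schedule_interval='@daily')

# Wait until the partition for the execution date shows up in the Glue Catalog.
wait_for_partition = AwsGlueCatalogPartitionSensor(
    task_id='wait_for_partition',
    database_name='my_db',          # illustrative
    table_name='my_table',          # illustrative
    expression="ds='{{ ds }}'",     # the sensor's default, shown explicitly
    aws_conn_id='aws_default',
    poke_interval=60,
    dag=dag)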
airflow/contrib/hooks/aws_glue_catalog_hook.py
@@ -0,0 +1,118 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.


from airflow.contrib.hooks.aws_hook import AwsHook


class AwsGlueCatalogHook(AwsHook):
    """
    Interact with AWS Glue Catalog

    :param aws_conn_id: ID of the Airflow connection where
        credentials and extra configuration are stored
    :type aws_conn_id: str
    :param region_name: aws region name (example: us-east-1)
    :type region_name: str
    """

    def __init__(self,
                 aws_conn_id='aws_default',
                 region_name=None,
                 *args,
                 **kwargs):
        self.region_name = region_name
        super(AwsGlueCatalogHook, self).__init__(aws_conn_id=aws_conn_id, *args, **kwargs)

    def get_conn(self):
        """
        Returns glue connection object.
        """
        self.conn = self.get_client_type('glue', self.region_name)
        return self.conn

    def get_partitions(self,
                       database_name,
                       table_name,
                       expression='',
                       page_size=None,
                       max_items=None):
        """
        Retrieves the partition values for a table.

        :param database_name: The name of the catalog database where the partitions reside.
        :type database_name: str
        :param table_name: The name of the partitions' table.
        :type table_name: str
        :param expression: An expression filtering the partitions to be returned.
            Please see official AWS documentation for further information.
            https://docs.aws.amazon.com/glue/latest/dg/aws-glue-api-catalog-partitions.html#aws-glue-api-catalog-partitions-GetPartitions
        :type expression: str
        :param page_size: pagination size
        :type page_size: int
        :param max_items: maximum items to return
        :type max_items: int
        :return: set of partition values where each value is a tuple since
            a partition may be composed of multiple columns. For example:
            {('2018-01-01','1'), ('2018-01-01','2')}
        """
        config = {
            'PageSize': page_size,
            'MaxItems': max_items,
        }

        paginator = self.get_conn().get_paginator('get_partitions')
        response = paginator.paginate(
            DatabaseName=database_name,
            TableName=table_name,
            Expression=expression,
            PaginationConfig=config
        )

        partitions = set()
        for page in response:
            for p in page['Partitions']:
                partitions.add(tuple(p['Values']))

        return partitions

    def check_for_partition(self, database_name, table_name, expression):
        """
        Checks whether a partition exists

        :param database_name: Name of hive database (schema) @table belongs to
        :type database_name: str
        :param table_name: Name of hive table @partition belongs to
        :type table_name: str
        :param expression: Expression that matches the partitions to check for
            (eg `a = 'b' AND c = 'd'`)
        :type expression: str
        :rtype: bool

        >>> hook = AwsGlueCatalogHook()
        >>> t = 'static_babynames_partitioned'
        >>> hook.check_for_partition('airflow', t, "ds='2015-01-01'")
        True
        """
        partitions = self.get_partitions(database_name, table_name, expression, max_items=1)

        if partitions:
            return True
        else:
            return False
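A quick sketch of the hook's API (the connection id, database, table, and expression values are hypothetical):

from airflow.contrib.hooks.aws_glue_catalog_hook import AwsGlueCatalogHook

hook = AwsGlueCatalogHook(aws_conn_id='aws_default', region_name='us-east-1')

# get_partitions returns a set of value tuples, e.g. {('2018-01-01', '1')}
partitions = hook.get_partitions('my_db', 'my_table', expression="ds='2018-01-01'")

# check_for_partition caps the lookup at max_items=1 and returns a bool
if hook.check_for_partition('my_db', 'my_table', "ds='2018-01-01'"):
    print('partition exists')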
airflow/contrib/sensors/aws_glue_catalog_partition_sensor.py
@@ -0,0 +1,93 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from airflow.sensors.base_sensor_operator import BaseSensorOperator
from airflow.utils.decorators import apply_defaults


class AwsGlueCatalogPartitionSensor(BaseSensorOperator):
    """
    Waits for a partition to show up in AWS Glue Catalog.

    :param table_name: The name of the table to wait for, supports the dot
        notation (my_database.my_table)
    :type table_name: str
    :param expression: The partition clause to wait for. This is passed as
        is to the AWS Glue Catalog API's get_partitions function,
        and supports SQL-like notation as in ``ds='2015-01-01'
        AND type='value'`` and comparison operators as in ``"ds>=2015-01-01"``.
        See https://docs.aws.amazon.com/glue/latest/dg/aws-glue-api-catalog-partitions.html
        #aws-glue-api-catalog-partitions-GetPartitions
    :type expression: str
    :param aws_conn_id: ID of the Airflow connection where
        credentials and extra configuration are stored
    :type aws_conn_id: str
    :param region_name: Optional aws region name (example: us-east-1). Uses region from connection
        if not specified.
    :type region_name: str
    :param database_name: The name of the catalog database where the partitions reside.
    :type database_name: str
    :param poke_interval: Time in seconds that the job should wait in
        between each try
    :type poke_interval: int
    """
    template_fields = ('database_name', 'table_name', 'expression',)
    ui_color = '#C5CAE9'

    @apply_defaults
    def __init__(self,
                 table_name, expression="ds='{{ ds }}'",
                 aws_conn_id='aws_default',
                 region_name=None,
                 database_name='default',
                 poke_interval=60 * 3,
                 *args,
                 **kwargs):
        super(AwsGlueCatalogPartitionSensor, self).__init__(
            poke_interval=poke_interval, *args, **kwargs)
        self.aws_conn_id = aws_conn_id
        self.region_name = region_name
        self.table_name = table_name
        self.expression = expression
        self.database_name = database_name

    def poke(self, context):
        """
        Checks for existence of the partition in the AWS Glue Catalog table
        """
        if '.' in self.table_name:
            self.database_name, self.table_name = self.table_name.split('.')
        self.log.info(
            'Poking for table {self.database_name}.{self.table_name}, '
            'expression {self.expression}'.format(**locals()))

        return self.get_hook().check_for_partition(
            self.database_name, self.table_name, self.expression)

    def get_hook(self):
        """
        Gets the AwsGlueCatalogHook
        """
        if not hasattr(self, 'hook'):
            from airflow.contrib.hooks.aws_glue_catalog_hook import AwsGlueCatalogHook
            self.hook = AwsGlueCatalogHook(
                aws_conn_id=self.aws_conn_id,
                region_name=self.region_name)

        return self.hook
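Since poke() splits dotted table names, the two instantiations below target the same partition set (task ids and names are illustrative):

# Both poke my_db.my_tbl; the dot notation in table_name overrides database_name.
AwsGlueCatalogPartitionSensor(task_id='s1', table_name='my_db.my_tbl')
AwsGlueCatalogPartitionSensor(task_id='s2', database_name='my_db', table_name='my_tbl')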
docs/code.rst
@@ -236,6 +236,7 @@ Sensors
 ^^^^^^^

 .. autoclass:: airflow.contrib.sensors.aws_athena_sensor.AthenaSensor
+.. autoclass:: airflow.contrib.sensors.aws_glue_catalog_partition_sensor.AwsGlueCatalogPartitionSensor
 .. autoclass:: airflow.contrib.sensors.aws_redshift_cluster_sensor.AwsRedshiftClusterSensor
 .. autoclass:: airflow.contrib.sensors.azure_cosmos_sensor.AzureCosmosDocumentSensor
 .. autoclass:: airflow.contrib.sensors.bash_sensor.BashSensor

@@ -420,6 +421,7 @@ Community contributed hooks
 .. autoclass:: airflow.contrib.hooks.aws_athena_hook.AWSAthenaHook
 .. autoclass:: airflow.contrib.hooks.aws_dynamodb_hook.AwsDynamoDBHook
 .. autoclass:: airflow.contrib.hooks.aws_firehose_hook.AwsFirehoseHook
+.. autoclass:: airflow.contrib.hooks.aws_glue_catalog_hook.AwsGlueCatalogHook
 .. autoclass:: airflow.contrib.hooks.aws_hook.AwsHook
 .. autoclass:: airflow.contrib.hooks.aws_lambda_hook.AwsLambdaHook
 .. autoclass:: airflow.contrib.hooks.aws_sns_hook.AwsSnsHook
setup.py
@@ -251,7 +251,7 @@ devel = [
     'lxml>=4.0.0',
     'mock',
     'mongomock',
-    'moto==1.1.19',
+    'moto==1.3.5',
     'nose',
     'nose-ignore-docstring==0.2',
     'nose-timer',
tests/contrib/hooks/test_aws_glue_catalog_hook.py
@@ -0,0 +1,110 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import unittest

from airflow.contrib.hooks.aws_glue_catalog_hook import AwsGlueCatalogHook

try:
    from moto import mock_glue
except ImportError:
    mock_glue = None

try:
    from unittest import mock
except ImportError:
    import mock


@unittest.skipIf(mock_glue is None,
                 "Skipping test because moto.mock_glue is not available")
class TestAwsGlueCatalogHook(unittest.TestCase):

    @mock_glue
    def test_get_conn_returns_a_boto3_connection(self):
        hook = AwsGlueCatalogHook(region_name="us-east-1")
        self.assertIsNotNone(hook.get_conn())

    @mock_glue
    def test_conn_id(self):
        hook = AwsGlueCatalogHook(aws_conn_id='my_aws_conn_id', region_name="us-east-1")
        self.assertEqual(hook.aws_conn_id, 'my_aws_conn_id')

    @mock_glue
    def test_region(self):
        hook = AwsGlueCatalogHook(region_name="us-west-2")
        self.assertEqual(hook.region_name, 'us-west-2')

    @mock_glue
    @mock.patch.object(AwsGlueCatalogHook, 'get_conn')
    def test_get_partitions_empty(self, mock_get_conn):
        # Wire the full client -> paginator -> paginate chain so the hook's
        # pagination loop actually iterates over this empty response.
        response = []
        mock_get_conn.return_value.get_paginator.return_value.paginate.return_value = response
        hook = AwsGlueCatalogHook(region_name="us-east-1")

        self.assertEqual(hook.get_partitions('db', 'tbl'), set())

    @mock_glue
    @mock.patch.object(AwsGlueCatalogHook, 'get_conn')
    def test_get_partitions(self, mock_get_conn):
        response = [{
            'Partitions': [{
                'Values': ['2015-01-01']
            }]
        }]
        mock_paginator = mock.Mock()
        mock_paginator.paginate.return_value = response
        mock_conn = mock.Mock()
        mock_conn.get_paginator.return_value = mock_paginator
        mock_get_conn.return_value = mock_conn
        hook = AwsGlueCatalogHook(region_name="us-east-1")
        result = hook.get_partitions('db',
                                     'tbl',
                                     expression='foo=bar',
                                     page_size=2,
                                     max_items=3)

        self.assertEqual(result, set([('2015-01-01',)]))
        mock_conn.get_paginator.assert_called_once_with('get_partitions')
        mock_paginator.paginate.assert_called_once_with(DatabaseName='db',
                                                        TableName='tbl',
                                                        Expression='foo=bar',
                                                        PaginationConfig={
                                                            'PageSize': 2,
                                                            'MaxItems': 3})

    @mock_glue
    @mock.patch.object(AwsGlueCatalogHook, 'get_partitions')
    def test_check_for_partition(self, mock_get_partitions):
        mock_get_partitions.return_value = set([('2018-01-01',)])
        hook = AwsGlueCatalogHook(region_name="us-east-1")

        self.assertTrue(hook.check_for_partition('db', 'tbl', 'expr'))
        mock_get_partitions.assert_called_once_with('db', 'tbl', 'expr', max_items=1)

    @mock_glue
    @mock.patch.object(AwsGlueCatalogHook, 'get_partitions')
    def test_check_for_partition_false(self, mock_get_partitions):
        mock_get_partitions.return_value = set()
        hook = AwsGlueCatalogHook(region_name="us-east-1")

        self.assertFalse(hook.check_for_partition('db', 'tbl', 'expr'))


if __name__ == '__main__':
    unittest.main()
tests/contrib/sensors/test_aws_glue_catalog_partition_sensor.py
@@ -0,0 +1,118 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import unittest

from airflow import configuration
from airflow.contrib.hooks.aws_glue_catalog_hook import AwsGlueCatalogHook
from airflow.contrib.sensors.aws_glue_catalog_partition_sensor import AwsGlueCatalogPartitionSensor

try:
    from moto import mock_glue
except ImportError:
    mock_glue = None

try:
    from unittest import mock
except ImportError:
    import mock


@unittest.skipIf(mock_glue is None,
                 "Skipping test because moto.mock_glue is not available")
class TestAwsGlueCatalogPartitionSensor(unittest.TestCase):

    task_id = 'test_glue_catalog_partition_sensor'

    def setUp(self):
        configuration.load_test_config()

    @mock_glue
    @mock.patch.object(AwsGlueCatalogHook, 'check_for_partition')
    def test_poke(self, mock_check_for_partition):
        mock_check_for_partition.return_value = True
        op = AwsGlueCatalogPartitionSensor(task_id=self.task_id,
                                           table_name='tbl')
        self.assertTrue(op.poke(None))

    @mock_glue
    @mock.patch.object(AwsGlueCatalogHook, 'check_for_partition')
    def test_poke_false(self, mock_check_for_partition):
        mock_check_for_partition.return_value = False
        op = AwsGlueCatalogPartitionSensor(task_id=self.task_id,
                                           table_name='tbl')
        self.assertFalse(op.poke(None))

    @mock_glue
    @mock.patch.object(AwsGlueCatalogHook, 'check_for_partition')
    def test_poke_default_args(self, mock_check_for_partition):
        table_name = 'test_glue_catalog_partition_sensor_tbl'
        op = AwsGlueCatalogPartitionSensor(task_id=self.task_id,
                                           table_name=table_name)
        op.poke(None)

        self.assertIsNone(op.hook.region_name)
        self.assertEqual(op.hook.aws_conn_id, 'aws_default')
        mock_check_for_partition.assert_called_once_with('default',
                                                         table_name,
                                                         "ds='{{ ds }}'")

    @mock_glue
    @mock.patch.object(AwsGlueCatalogHook, 'check_for_partition')
    def test_poke_nondefault_args(self, mock_check_for_partition):
        table_name = 'my_table'
        expression = 'col=val'
        aws_conn_id = 'my_aws_conn_id'
        region_name = 'us-west-2'
        database_name = 'my_db'
        poke_interval = 2
        timeout = 3
        op = AwsGlueCatalogPartitionSensor(task_id=self.task_id,
                                           table_name=table_name,
                                           expression=expression,
                                           aws_conn_id=aws_conn_id,
                                           region_name=region_name,
                                           database_name=database_name,
                                           poke_interval=poke_interval,
                                           timeout=timeout)
        op.poke(None)

        self.assertEqual(op.hook.region_name, region_name)
        self.assertEqual(op.hook.aws_conn_id, aws_conn_id)
        self.assertEqual(op.poke_interval, poke_interval)
        self.assertEqual(op.timeout, timeout)
        mock_check_for_partition.assert_called_once_with(database_name,
                                                         table_name,
                                                         expression)

    @mock_glue
    @mock.patch.object(AwsGlueCatalogHook, 'check_for_partition')
    def test_dot_notation(self, mock_check_for_partition):
        db_table = 'my_db.my_tbl'
        op = AwsGlueCatalogPartitionSensor(task_id=self.task_id,
                                           table_name=db_table)
        op.poke(None)

        mock_check_for_partition.assert_called_once_with('my_db',
                                                         'my_tbl',
                                                         "ds='{{ ds }}'")


if __name__ == '__main__':
    unittest.main()