[AIRFLOW-1795] Correctly call S3Hook after migration to boto3

In the migration of S3Hook to boto3 the connection ID parameter
changed to `aws_conn_id`. This fixes the uses of `s3_conn_id` in the
code base and adds a note to UPDATING.md about the change.

In correcting the tests for S3ToHiveTransfer I noticed that
S3Hook.get_key was returning a dictionary, rather than the S3.Object
mentioned in its docstring. The important thing that was missing was
the ability to get the key name from the return value of a call to
get_wildcard_key.

Closes #2795 from ashb/AIRFLOW-1795-s3hook_boto3_fixes
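In practice this means hook and DAG code only has to rename one keyword argument. The snippet below is an illustrative sketch rather than code from this commit; the connection ID default comes from the change itself, while the bucket and key names are invented:

```python
from airflow.hooks.S3_hook import S3Hook

# Before this change (boto2-based hook):
#   hook = S3Hook(s3_conn_id='s3_default')
# After it, the parameter and the default connection are renamed:
hook = S3Hook(aws_conn_id='aws_default')

# get_key()/get_wildcard_key() now return a boto3 s3.Object, so the key
# name is available directly as obj.key.
obj = hook.get_key('some/key.txt', bucket_name='some-bucket')  # invented names
print(obj.key)
```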
Ash Berlin-Taylor 2017-11-18 14:07:38 +01:00 committed by Bolke de Bruin
Parent 54c03f3262
Commit 98df0d6e3b
7 changed files with 69 additions and 55 deletions

View file

@@ -6,12 +6,23 @@ assists people when migrating to a new version.
## Airflow 1.9
### SSH Hook updates, along with new SSH Operator & SFTP Operator
SSH Hook now uses the Paramiko library to create an ssh client connection, instead of the sub-process based ssh command execution used previously (<1.9.0), so this is backward incompatible.
- update SSHHook constructor
- use the SSHOperator class in place of SSHExecuteOperator, which is now removed. Refer to test_ssh_operator.py for usage info.
- SFTPOperator is added to perform secure file transfer from server A to server B. Refer to test_sftp_operator.py for usage info.
- No updates are required if you are using ftpHook; it will continue to work as is.
### S3Hook switched to use Boto3
The airflow.hooks.S3_hook.S3Hook has been switched to use boto3 instead of the older boto (a.k.a. boto2). This results in a few backwards incompatible changes to the S3Hook class:
- the constructor no longer accepts `s3_conn_id`; the parameter is now called `aws_conn_id`
- the default connection is now "aws_default" instead of "s3_default"
- the type of objects returned by `get_bucket` is now boto3.s3.Bucket
- the return type of `get_key` and `get_wildcard_key` is now a boto3.s3.Object
If you use any of the following in your DAGs and specify a connection ID, you will need to update the parameter name for the connection to "aws_conn_id": S3ToHiveTransfer, S3PrefixSensor, S3KeySensor, RedshiftToS3Transfer.
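As an illustration (a hedged sketch, not part of UPDATING.md itself; the task id, bucket, and key are invented), a DAG using S3KeySensor only needs the keyword renamed:

```python
from airflow.operators.sensors import S3KeySensor

# Before Airflow 1.9:
#   S3KeySensor(task_id='wait_for_key', bucket_key='s3://some-bucket/some/key',
#               s3_conn_id='s3_default')
# From 1.9 onwards:
wait_for_key = S3KeySensor(
    task_id='wait_for_key',                  # invented task id
    bucket_key='s3://some-bucket/some/key',  # invented bucket/key
    aws_conn_id='aws_default',               # renamed parameter, new default
)
```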
### Logging update
The logging structure of Airflow has been rewritten to make configuration easier and the logging system more transparent.

View file

@@ -123,7 +123,7 @@ class S3Hook(AwsHook):
def get_key(self, key, bucket_name=None):
"""
Returns a boto3.S3.Key object
Returns a boto3.s3.Object
:param key: the path to the key
:type key: str
@@ -132,8 +132,10 @@ class S3Hook(AwsHook):
"""
if not bucket_name:
(bucket_name, key) = self.parse_s3_url(key)
return self.get_conn().get_object(Bucket=bucket_name, Key=key)
obj = self.get_resource_type('s3').Object(bucket_name, key)
obj.load()
return obj
def read_key(self, key, bucket_name=None):
"""
@@ -144,9 +146,9 @@ class S3Hook(AwsHook):
:param bucket_name: Name of the bucket in which the file is stored
:type bucket_name: str
"""
obj = self.get_key(key, bucket_name)
return obj['Body'].read().decode('utf-8')
return obj.get()['Body'].read().decode('utf-8')
def check_for_wildcard_key(self,
wildcard_key, bucket_name=None, delimiter=''):
@@ -159,7 +161,7 @@ class S3Hook(AwsHook):
def get_wildcard_key(self, wildcard_key, bucket_name=None, delimiter=''):
"""
Returns a boto3.s3.Key object matching the regular expression
Returns a boto3.s3.Object object matching the regular expression
:param regex_key: the path to the key
:type regex_key: str
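A short usage sketch (not part of the diff) of the object that `get_key` now returns; the connection, bucket, and key names below are assumptions:

```python
from airflow.hooks.S3_hook import S3Hook

hook = S3Hook(aws_conn_id='aws_default')
obj = hook.get_key('data/report.csv', bucket_name='some-bucket')  # boto3 s3.Object

print(obj.key)                                    # the key name, 'data/report.csv'
text = obj.get()['Body'].read().decode('utf-8')   # what read_key() now does internally

# The object can also stream itself to a local file, as S3ToHiveTransfer now does:
with open('/tmp/report.csv', 'wb') as f:
    obj.download_fileobj(f)
```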

View file

@@ -30,8 +30,8 @@ class RedshiftToS3Transfer(BaseOperator):
:type s3_key: string
:param redshift_conn_id: reference to a specific redshift database
:type redshift_conn_id: string
:param s3_conn_id: reference to a specific S3 connection
:type s3_conn_id: string
:param aws_conn_id: reference to a specific S3 connection
:type aws_conn_id: string
:param options: reference to a list of UNLOAD options
:type options: list
"""
@@ -48,7 +48,7 @@ class RedshiftToS3Transfer(BaseOperator):
s3_bucket,
s3_key,
redshift_conn_id='redshift_default',
s3_conn_id='s3_default',
aws_conn_id='aws_default',
unload_options=tuple(),
autocommit=False,
parameters=None,
@@ -59,14 +59,14 @@ class RedshiftToS3Transfer(BaseOperator):
self.s3_bucket = s3_bucket
self.s3_key = s3_key
self.redshift_conn_id = redshift_conn_id
self.s3_conn_id = s3_conn_id
self.aws_conn_id = aws_conn_id
self.unload_options = unload_options
self.autocommit = autocommit
self.parameters = parameters
def execute(self, context):
self.hook = PostgresHook(postgres_conn_id=self.redshift_conn_id)
self.s3 = S3Hook(s3_conn_id=self.s3_conn_id)
self.s3 = S3Hook(aws_conn_id=self.aws_conn_id)
a_key, s_key = self.s3.get_credentials()
unload_options = '\n\t\t\t'.join(self.unload_options)

View file

@@ -37,12 +37,12 @@ class S3FileTransformOperator(BaseOperator):
:param source_s3_key: The key to be retrieved from S3
:type source_s3_key: str
:param source_s3_conn_id: source s3 connection
:type source_s3_conn_id: str
:param source_aws_conn_id: source s3 connection
:type source_aws_conn_id: str
:param dest_s3_key: The key to be written from S3
:type dest_s3_key: str
:param dest_s3_conn_id: destination s3 connection
:type dest_s3_conn_id: str
:param dest_aws_conn_id: destination s3 connection
:type dest_aws_conn_id: str
:param replace: Replace dest S3 key if it already exists
:type replace: bool
:param transform_script: location of the executable transformation script
@@ -59,21 +59,21 @@ class S3FileTransformOperator(BaseOperator):
source_s3_key,
dest_s3_key,
transform_script,
source_s3_conn_id='s3_default',
dest_s3_conn_id='s3_default',
source_aws_conn_id='aws_default',
dest_aws_conn_id='aws_default',
replace=False,
*args, **kwargs):
super(S3FileTransformOperator, self).__init__(*args, **kwargs)
self.source_s3_key = source_s3_key
self.source_s3_conn_id = source_s3_conn_id
self.source_aws_conn_id = source_aws_conn_id
self.dest_s3_key = dest_s3_key
self.dest_s3_conn_id = dest_s3_conn_id
self.dest_aws_conn_id = dest_aws_conn_id
self.replace = replace
self.transform_script = transform_script
def execute(self, context):
source_s3 = S3Hook(s3_conn_id=self.source_s3_conn_id)
dest_s3 = S3Hook(s3_conn_id=self.dest_s3_conn_id)
source_s3 = S3Hook(aws_conn_id=self.source_aws_conn_id)
dest_s3 = S3Hook(aws_conn_id=self.dest_aws_conn_id)
self.log.info("Downloading source S3 file %s", self.source_s3_key)
if not source_s3.check_for_key(self.source_s3_key):
raise AirflowException("The source key {0} does not exist".format(self.source_s3_key))
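For DAG authors this operator only changes keyword names; a hedged example assuming the 1.9 module layout, with invented task id, keys, and script path:

```python
from airflow.operators.s3_file_transform_operator import S3FileTransformOperator

transform = S3FileTransformOperator(
    task_id='transform_file',                        # invented task id
    source_s3_key='s3://src-bucket/in.csv',          # invented source key
    dest_s3_key='s3://dest-bucket/out.csv',          # invented destination key
    transform_script='/usr/local/bin/transform.sh',  # invented script path
    source_aws_conn_id='aws_default',                # was source_s3_conn_id
    dest_aws_conn_id='aws_default',                  # was dest_s3_conn_id
    replace=True,
)
```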

View file

@@ -71,8 +71,8 @@ class S3ToHiveTransfer(BaseOperator):
:type wildcard_match: bool
:param delimiter: field delimiter in the file
:type delimiter: str
:param s3_conn_id: source s3 connection
:type s3_conn_id: str
:param aws_conn_id: source s3 connection
:type aws_conn_id: str
:param hive_cli_conn_id: destination hive connection
:type hive_cli_conn_id: str
:param input_compressed: Boolean to determine if file decompression is
@@ -99,7 +99,7 @@ class S3ToHiveTransfer(BaseOperator):
headers=False,
check_headers=False,
wildcard_match=False,
s3_conn_id='s3_default',
aws_conn_id='aws_default',
hive_cli_conn_id='hive_cli_default',
input_compressed=False,
tblproperties=None,
@@ -116,7 +116,7 @@ class S3ToHiveTransfer(BaseOperator):
self.check_headers = check_headers
self.wildcard_match = wildcard_match
self.hive_cli_conn_id = hive_cli_conn_id
self.s3_conn_id = s3_conn_id
self.aws_conn_id = aws_conn_id
self.input_compressed = input_compressed
self.tblproperties = tblproperties
@@ -127,7 +127,7 @@ class S3ToHiveTransfer(BaseOperator):
def execute(self, context):
# Downloading file from S3
self.s3 = S3Hook(s3_conn_id=self.s3_conn_id)
self.s3 = S3Hook(aws_conn_id=self.aws_conn_id)
self.hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
self.log.info("Downloading S3 file")
@@ -143,14 +143,13 @@ class S3ToHiveTransfer(BaseOperator):
s3_key_object = self.s3.get_key(self.s3_key)
root, file_ext = os.path.splitext(s3_key_object.key)
with TemporaryDirectory(prefix='tmps32hive_') as tmp_dir,\
NamedTemporaryFile(mode="w",
NamedTemporaryFile(mode="wb",
dir=tmp_dir,
suffix=file_ext) as f:
self.log.info("Dumping S3 key {0} contents to local file {1}"
.format(s3_key_object.key, f.name))
s3_key_object.get_contents_to_file(f)
s3_key_object.download_fileobj(f)
f.flush()
self.s3.connection.close()
if not self.headers:
self.log.info("Loading file %s into Hive", f.name)
self.hive.load_file(

View file

@@ -501,8 +501,8 @@ class S3KeySensor(BaseSensorOperator):
:param wildcard_match: whether the bucket_key should be interpreted as a
Unix wildcard pattern
:type wildcard_match: bool
:param s3_conn_id: a reference to the s3 connection
:type s3_conn_id: str
:param aws_conn_id: a reference to the s3 connection
:type aws_conn_id: str
"""
template_fields = ('bucket_key', 'bucket_name')
@@ -511,7 +511,7 @@ class S3KeySensor(BaseSensorOperator):
self, bucket_key,
bucket_name=None,
wildcard_match=False,
s3_conn_id='s3_default',
aws_conn_id='aws_default',
*args, **kwargs):
super(S3KeySensor, self).__init__(*args, **kwargs)
# Parse
@@ -528,11 +528,11 @@ class S3KeySensor(BaseSensorOperator):
self.bucket_name = bucket_name
self.bucket_key = bucket_key
self.wildcard_match = wildcard_match
self.s3_conn_id = s3_conn_id
self.aws_conn_id = aws_conn_id
def poke(self, context):
from airflow.hooks.S3_hook import S3Hook
hook = S3Hook(s3_conn_id=self.s3_conn_id)
hook = S3Hook(aws_conn_id=self.aws_conn_id)
full_url = "s3://" + self.bucket_name + "/" + self.bucket_key
self.log.info('Poking for key : {full_url}'.format(**locals()))
if self.wildcard_match:
@@ -565,7 +565,7 @@ class S3PrefixSensor(BaseSensorOperator):
def __init__(
self, bucket_name,
prefix, delimiter='/',
s3_conn_id='s3_default',
aws_conn_id='aws_default',
*args, **kwargs):
super(S3PrefixSensor, self).__init__(*args, **kwargs)
# Parse
@@ -573,13 +573,13 @@ class S3PrefixSensor(BaseSensorOperator):
self.prefix = prefix
self.delimiter = delimiter
self.full_url = "s3://" + bucket_name + '/' + prefix
self.s3_conn_id = s3_conn_id
self.aws_conn_id = aws_conn_id
def poke(self, context):
self.log.info('Poking for prefix : {self.prefix}\n'
'in bucket s3://{self.bucket_name}'.format(**locals()))
'in bucket s3://{self.bucket_name}'.format(**locals()))
from airflow.hooks.S3_hook import S3Hook
hook = S3Hook(s3_conn_id=self.s3_conn_id)
hook = S3Hook(aws_conn_id=self.aws_conn_id)
return hook.check_for_prefix(
prefix=self.prefix,
delimiter=self.delimiter,

View file

@@ -32,6 +32,12 @@ import shutil
import filecmp
import errno
try:
import boto3
from moto import mock_s3
except ImportError:
mock_s3 = None
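The test now exercises the hook against moto's in-memory S3 instead of patching S3Hook. A minimal standalone sketch of that pattern (the bucket and key names are made up, and this is not code from the commit):

```python
import boto3
from moto import mock_s3


@mock_s3
def test_s3_round_trip():
    # Every boto3 S3 call inside the decorated function hits moto's
    # in-memory backend rather than real AWS.
    conn = boto3.client('s3', region_name='us-east-1')
    conn.create_bucket(Bucket='bucket')
    conn.put_object(Bucket='bucket', Key='some/key.txt', Body=b'hello')
    body = conn.get_object(Bucket='bucket', Key='some/key.txt')['Body'].read()
    assert body == b'hello'
```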
class S3ToHiveTransferTest(unittest.TestCase):
@@ -128,10 +134,6 @@ class S3ToHiveTransferTest(unittest.TestCase):
key = ext + "_" + ('h' if header else 'nh')
return key
def _cp_file_contents(self, fn_src, fn_dest):
with open(fn_src, 'rb') as f_src, open(fn_dest, 'wb') as f_dest:
shutil.copyfileobj(f_src, f_dest)
def _check_file_equality(self, fn_1, fn_2, ext):
# gz files contain mtime and filename in the header that
# causes filecmp to return False even if contents are identical
@@ -205,13 +207,15 @@ class S3ToHiveTransferTest(unittest.TestCase):
msg="bz2 Compressed file not as expected")
@unittest.skipIf(mock is None, 'mock package not present')
@unittest.skipIf(mock_s3 is None, 'moto package not present')
@mock.patch('airflow.operators.s3_to_hive_operator.HiveCliHook')
@mock.patch('airflow.operators.s3_to_hive_operator.S3Hook')
def test_execute(self, mock_s3hook, mock_hiveclihook):
@mock_s3
def test_execute(self, mock_hiveclihook):
conn = boto3.client('s3')
conn.create_bucket(Bucket='bucket')
# Testing txt, gz, bz2 files with and without header row
for test in product(['.txt', '.gz', '.bz2'], [True, False]):
ext = test[0]
has_header = test[1]
for (ext, has_header) in product(['.txt', '.gz', '.bz2'], [True, False]):
self.kwargs['headers'] = has_header
self.kwargs['check_headers'] = has_header
logging.info("Testing {0} format {1} header".
@@ -219,15 +223,13 @@ class S3ToHiveTransferTest(unittest.TestCase):
('with' if has_header else 'without'))
)
self.kwargs['input_compressed'] = (False if ext == '.txt' else True)
self.kwargs['s3_key'] = self.s3_key + ext
self.kwargs['s3_key'] = 's3://bucket/' + self.s3_key + ext
ip_fn = self._get_fn(ext, self.kwargs['headers'])
op_fn = self._get_fn(ext, False)
# Mock s3 object returned by S3Hook
mock_s3_object = mock.Mock(key=self.kwargs['s3_key'])
mock_s3_object.get_contents_to_file.side_effect = \
lambda dest_file: \
self._cp_file_contents(ip_fn, dest_file.name)
mock_s3hook().get_key.return_value = mock_s3_object
# Upload the file into the Mocked S3 bucket
conn.upload_file(ip_fn, 'bucket', self.s3_key + ext)
# file parameter to HiveCliHook.load_file is compared
# against expected file output
mock_hiveclihook().load_file.side_effect = \