From 98df0d6e3b2e2b439ab46d6c9ba736777202414a Mon Sep 17 00:00:00 2001
From: Ash Berlin-Taylor
Date: Sat, 18 Nov 2017 14:07:38 +0100
Subject: [PATCH] [AIRFLOW-1795] Correctly call S3Hook after migration to boto3

In the migration of S3Hook to boto3 the connection ID parameter changed
to `aws_conn_id`. This fixes the uses of `s3_conn_id` in the code base
and adds a note to UPDATING.md about the change.

While correcting the tests for S3ToHiveTransfer I noticed that
S3Hook.get_key was returning a dictionary rather than the S3.Object
mentioned in its docstring. The important missing piece was the ability
to get the key name from the return value of a call to get_wildcard_key.

Closes #2795 from ashb/AIRFLOW-1795-s3hook_boto3_fixes
---
 UPDATING.md                                  | 13 ++++++-
 airflow/hooks/S3_hook.py                     | 14 ++++----
 airflow/operators/redshift_to_s3_operator.py | 10 +++---
 .../operators/s3_file_transform_operator.py  | 20 +++++------
 airflow/operators/s3_to_hive_operator.py     | 15 ++++----
 airflow/operators/sensors.py                 | 18 +++++-----
 tests/operators/s3_to_hive_operator.py       | 34 ++++++++++---------
 7 files changed, 69 insertions(+), 55 deletions(-)

diff --git a/UPDATING.md b/UPDATING.md
index 6abcaf785a..3c7d5494ac 100644
--- a/UPDATING.md
+++ b/UPDATING.md
@@ -6,12 +6,23 @@ assists people when migrating to a new version.
 ## Airflow 1.9
 
 ### SSH Hook updates, along with new SSH Operator & SFTP Operator
-  SSH Hook now uses Paramiko library to create ssh client connection, instead of sub-process based ssh command execution previously (<1.9.0), so this is backward incompatible.
+
+SSH Hook now uses Paramiko library to create ssh client connection, instead of sub-process based ssh command execution previously (<1.9.0), so this is backward incompatible.
 - update SSHHook constructor
 - use SSHOperator class in place of SSHExecuteOperator which is removed now. Refer test_ssh_operator.py for usage info.
 - SFTPOperator is added to perform secure file transfer from serverA to serverB. Refer test_sftp_operator.py.py for usage info.
 - No updates are required if you are using ftpHook, it will continue work as is.
 
+### S3Hook switched to use Boto3
+
+The airflow.hooks.S3_hook.S3Hook has been switched to use boto3 instead of the older boto (a.k.a. boto2). This results in a few backwards-incompatible changes to S3Hook:
+ - the constructor no longer accepts `s3_conn_id`; the parameter is now called `aws_conn_id`
+ - the default connection is now "aws_default" instead of "s3_default"
+ - the return type of `get_bucket` is now a boto3.s3.Bucket
+ - the return type of `get_key` and `get_wildcard_key` is now a boto3.s3.Object
+
+If you use any of the following classes in your DAGs and specify a connection ID, you will need to rename that parameter to "aws_conn_id": S3ToHiveTransfer, S3PrefixSensor, S3KeySensor, RedshiftToS3Transfer.
+
 ### Logging update
 
 The logging structure of Airflow has been rewritten to make configuration easier and the logging system more transparent.
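To illustrate the rename described in the UPDATING.md note above, here is a minimal DAG-side sketch; the connection ID, bucket, and key names are hypothetical and do not appear anywhere in this patch.

```python
# Minimal sketch of the s3_conn_id -> aws_conn_id rename; all names here
# ('my_s3_conn', 'my-bucket', 'data/my-file.csv') are made-up examples.
from airflow.hooks.S3_hook import S3Hook
from airflow.operators.sensors import S3KeySensor

# Previously: S3Hook(s3_conn_id='my_s3_conn')
hook = S3Hook(aws_conn_id='my_s3_conn')

# get_key()/get_wildcard_key() now return a boto3 Object, so the key name is
# available as .key and the contents via .get()['Body'] (or via read_key()).
obj = hook.get_key('data/my-file.csv', bucket_name='my-bucket')
print(obj.key)

# The sensors and transfer operators take the renamed parameter as well.
sensor = S3KeySensor(
    task_id='wait_for_my_file',
    bucket_name='my-bucket',
    bucket_key='data/my-file.csv',
    aws_conn_id='my_s3_conn',
)
```

The same `aws_conn_id` parameter applies to S3PrefixSensor, S3ToHiveTransfer, and RedshiftToS3Transfer.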
diff --git a/airflow/hooks/S3_hook.py b/airflow/hooks/S3_hook.py
index b16566f9bf..226b520c77 100644
--- a/airflow/hooks/S3_hook.py
+++ b/airflow/hooks/S3_hook.py
@@ -123,7 +123,7 @@ class S3Hook(AwsHook):
 
     def get_key(self, key, bucket_name=None):
         """
-        Returns a boto3.S3.Key object
+        Returns a boto3.s3.Object
 
         :param key: the path to the key
         :type key: str
@@ -132,8 +132,10 @@ class S3Hook(AwsHook):
         """
         if not bucket_name:
             (bucket_name, key) = self.parse_s3_url(key)
-
-        return self.get_conn().get_object(Bucket=bucket_name, Key=key)
+
+        obj = self.get_resource_type('s3').Object(bucket_name, key)
+        obj.load()
+        return obj
 
     def read_key(self, key, bucket_name=None):
         """
@@ -144,9 +146,9 @@ class S3Hook(AwsHook):
         :param bucket_name: Name of the bucket in which the file is stored
         :type bucket_name: str
         """
-
+
         obj = self.get_key(key, bucket_name)
-        return obj['Body'].read().decode('utf-8')
+        return obj.get()['Body'].read().decode('utf-8')
 
     def check_for_wildcard_key(self, wildcard_key,
                                bucket_name=None, delimiter=''):
@@ -159,7 +161,7 @@ class S3Hook(AwsHook):
 
     def get_wildcard_key(self, wildcard_key, bucket_name=None, delimiter=''):
         """
-        Returns a boto3.s3.Key object matching the regular expression
+        Returns a boto3.s3.Object object matching the regular expression
 
         :param regex_key: the path to the key
         :type regex_key: str
diff --git a/airflow/operators/redshift_to_s3_operator.py b/airflow/operators/redshift_to_s3_operator.py
index 683ff9c1eb..5553a2af34 100644
--- a/airflow/operators/redshift_to_s3_operator.py
+++ b/airflow/operators/redshift_to_s3_operator.py
@@ -30,8 +30,8 @@ class RedshiftToS3Transfer(BaseOperator):
     :type s3_key: string
     :param redshift_conn_id: reference to a specific redshift database
     :type redshift_conn_id: string
-    :param s3_conn_id: reference to a specific S3 connection
-    :type s3_conn_id: string
+    :param aws_conn_id: reference to a specific S3 connection
+    :type aws_conn_id: string
     :param options: reference to a list of UNLOAD options
     :type options: list
     """
@@ -48,7 +48,7 @@ class RedshiftToS3Transfer(BaseOperator):
             s3_bucket,
             s3_key,
             redshift_conn_id='redshift_default',
-            s3_conn_id='s3_default',
+            aws_conn_id='aws_default',
             unload_options=tuple(),
             autocommit=False,
             parameters=None,
@@ -59,14 +59,14 @@ class RedshiftToS3Transfer(BaseOperator):
         self.s3_bucket = s3_bucket
         self.s3_key = s3_key
         self.redshift_conn_id = redshift_conn_id
-        self.s3_conn_id = s3_conn_id
+        self.aws_conn_id = aws_conn_id
         self.unload_options = unload_options
         self.autocommit = autocommit
         self.parameters = parameters
 
     def execute(self, context):
         self.hook = PostgresHook(postgres_conn_id=self.redshift_conn_id)
-        self.s3 = S3Hook(s3_conn_id=self.s3_conn_id)
+        self.s3 = S3Hook(aws_conn_id=self.aws_conn_id)
         a_key, s_key = self.s3.get_credentials()
         unload_options = '\n\t\t\t'.join(self.unload_options)
 
diff --git a/airflow/operators/s3_file_transform_operator.py b/airflow/operators/s3_file_transform_operator.py
index 68c733cc7e..e105e3d23a 100644
--- a/airflow/operators/s3_file_transform_operator.py
+++ b/airflow/operators/s3_file_transform_operator.py
@@ -37,12 +37,12 @@ class S3FileTransformOperator(BaseOperator):
 
     :param source_s3_key: The key to be retrieved from S3
     :type source_s3_key: str
-    :param source_s3_conn_id: source s3 connection
-    :type source_s3_conn_id: str
+    :param source_aws_conn_id: source s3 connection
+    :type source_aws_conn_id: str
     :param dest_s3_key: The key to be written from S3
     :type dest_s3_key: str
-    :param dest_s3_conn_id: destination s3 connection
-    :type dest_s3_conn_id: str
+    :param dest_aws_conn_id: destination s3 connection
+    :type dest_aws_conn_id: str
     :param replace: Replace dest S3 key if it already exists
     :type replace: bool
     :param transform_script: location of the executable transformation script
@@ -59,21 +59,21 @@ class S3FileTransformOperator(BaseOperator):
             source_s3_key,
             dest_s3_key,
             transform_script,
-            source_s3_conn_id='s3_default',
-            dest_s3_conn_id='s3_default',
+            source_aws_conn_id='aws_default',
+            dest_aws_conn_id='aws_default',
             replace=False,
             *args, **kwargs):
         super(S3FileTransformOperator, self).__init__(*args, **kwargs)
         self.source_s3_key = source_s3_key
-        self.source_s3_conn_id = source_s3_conn_id
+        self.source_aws_conn_id = source_aws_conn_id
         self.dest_s3_key = dest_s3_key
-        self.dest_s3_conn_id = dest_s3_conn_id
+        self.dest_aws_conn_id = dest_aws_conn_id
         self.replace = replace
         self.transform_script = transform_script
 
     def execute(self, context):
-        source_s3 = S3Hook(s3_conn_id=self.source_s3_conn_id)
-        dest_s3 = S3Hook(s3_conn_id=self.dest_s3_conn_id)
+        source_s3 = S3Hook(aws_conn_id=self.source_aws_conn_id)
+        dest_s3 = S3Hook(aws_conn_id=self.dest_aws_conn_id)
         self.log.info("Downloading source S3 file %s", self.source_s3_key)
         if not source_s3.check_for_key(self.source_s3_key):
             raise AirflowException("The source key {0} does not exist".format(self.source_s3_key))
diff --git a/airflow/operators/s3_to_hive_operator.py b/airflow/operators/s3_to_hive_operator.py
index 2b4aceb99c..148c643287 100644
--- a/airflow/operators/s3_to_hive_operator.py
+++ b/airflow/operators/s3_to_hive_operator.py
@@ -71,8 +71,8 @@ class S3ToHiveTransfer(BaseOperator):
     :type wildcard_match: bool
     :param delimiter: field delimiter in the file
     :type delimiter: str
-    :param s3_conn_id: source s3 connection
-    :type s3_conn_id: str
+    :param aws_conn_id: source s3 connection
+    :type aws_conn_id: str
     :param hive_cli_conn_id: destination hive connection
     :type hive_cli_conn_id: str
     :param input_compressed: Boolean to determine if file decompression is
@@ -99,7 +99,7 @@ class S3ToHiveTransfer(BaseOperator):
             headers=False,
             check_headers=False,
             wildcard_match=False,
-            s3_conn_id='s3_default',
+            aws_conn_id='aws_default',
             hive_cli_conn_id='hive_cli_default',
             input_compressed=False,
             tblproperties=None,
@@ -116,7 +116,7 @@ class S3ToHiveTransfer(BaseOperator):
         self.check_headers = check_headers
         self.wildcard_match = wildcard_match
         self.hive_cli_conn_id = hive_cli_conn_id
-        self.s3_conn_id = s3_conn_id
+        self.aws_conn_id = aws_conn_id
         self.input_compressed = input_compressed
         self.tblproperties = tblproperties
 
@@ -127,7 +127,7 @@ class S3ToHiveTransfer(BaseOperator):
 
     def execute(self, context):
         # Downloading file from S3
-        self.s3 = S3Hook(s3_conn_id=self.s3_conn_id)
+        self.s3 = S3Hook(aws_conn_id=self.aws_conn_id)
         self.hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
 
         self.log.info("Downloading S3 file")
@@ -143,14 +143,13 @@ class S3ToHiveTransfer(BaseOperator):
             s3_key_object = self.s3.get_key(self.s3_key)
         root, file_ext = os.path.splitext(s3_key_object.key)
         with TemporaryDirectory(prefix='tmps32hive_') as tmp_dir,\
-                NamedTemporaryFile(mode="w",
+                NamedTemporaryFile(mode="wb",
                                    dir=tmp_dir,
                                    suffix=file_ext) as f:
             self.log.info("Dumping S3 key {0} contents to local file {1}"
                           .format(s3_key_object.key, f.name))
-            s3_key_object.get_contents_to_file(f)
+            s3_key_object.download_fileobj(f)
             f.flush()
-            self.s3.connection.close()
             if not self.headers:
                 self.log.info("Loading file %s into Hive", f.name)
                 self.hive.load_file(
diff --git a/airflow/operators/sensors.py b/airflow/operators/sensors.py
index da7a62f1df..bd073b885b 100644
--- a/airflow/operators/sensors.py
+++ b/airflow/operators/sensors.py
@@ -501,8 +501,8 @@ class S3KeySensor(BaseSensorOperator):
     :param wildcard_match: whether the bucket_key should be interpreted as a
         Unix wildcard pattern
     :type wildcard_match: bool
-    :param s3_conn_id: a reference to the s3 connection
-    :type s3_conn_id: str
+    :param aws_conn_id: a reference to the s3 connection
+    :type aws_conn_id: str
     """
 
     template_fields = ('bucket_key', 'bucket_name')
@@ -511,7 +511,7 @@ class S3KeySensor(BaseSensorOperator):
             self, bucket_key,
             bucket_name=None,
             wildcard_match=False,
-            s3_conn_id='s3_default',
+            aws_conn_id='aws_default',
             *args, **kwargs):
         super(S3KeySensor, self).__init__(*args, **kwargs)
         # Parse
@@ -528,11 +528,11 @@ class S3KeySensor(BaseSensorOperator):
         self.bucket_name = bucket_name
         self.bucket_key = bucket_key
         self.wildcard_match = wildcard_match
-        self.s3_conn_id = s3_conn_id
+        self.aws_conn_id = aws_conn_id
 
     def poke(self, context):
         from airflow.hooks.S3_hook import S3Hook
-        hook = S3Hook(s3_conn_id=self.s3_conn_id)
+        hook = S3Hook(aws_conn_id=self.aws_conn_id)
         full_url = "s3://" + self.bucket_name + "/" + self.bucket_key
         self.log.info('Poking for key : {full_url}'.format(**locals()))
         if self.wildcard_match:
@@ -565,7 +565,7 @@ class S3PrefixSensor(BaseSensorOperator):
     def __init__(
             self, bucket_name,
             prefix, delimiter='/',
-            s3_conn_id='s3_default',
+            aws_conn_id='aws_default',
             *args, **kwargs):
         super(S3PrefixSensor, self).__init__(*args, **kwargs)
         # Parse
@@ -573,13 +573,13 @@ class S3PrefixSensor(BaseSensorOperator):
         self.prefix = prefix
         self.delimiter = delimiter
         self.full_url = "s3://" + bucket_name + '/' + prefix
-        self.s3_conn_id = s3_conn_id
+        self.aws_conn_id = aws_conn_id
 
     def poke(self, context):
         self.log.info('Poking for prefix : {self.prefix}\n'
-                'in bucket s3://{self.bucket_name}'.format(**locals()))
+                      'in bucket s3://{self.bucket_name}'.format(**locals()))
         from airflow.hooks.S3_hook import S3Hook
-        hook = S3Hook(s3_conn_id=self.s3_conn_id)
+        hook = S3Hook(aws_conn_id=self.aws_conn_id)
         return hook.check_for_prefix(
             prefix=self.prefix,
             delimiter=self.delimiter,
diff --git a/tests/operators/s3_to_hive_operator.py b/tests/operators/s3_to_hive_operator.py
index faab11e15f..021c9c4292 100644
--- a/tests/operators/s3_to_hive_operator.py
+++ b/tests/operators/s3_to_hive_operator.py
@@ -32,6 +32,12 @@ import shutil
 import filecmp
 import errno
 
+try:
+    import boto3
+    from moto import mock_s3
+except ImportError:
+    mock_s3 = None
+
 
 class S3ToHiveTransferTest(unittest.TestCase):
 
@@ -128,10 +134,6 @@ class S3ToHiveTransferTest(unittest.TestCase):
         key = ext + "_" + ('h' if header else 'nh')
         return key
 
-    def _cp_file_contents(self, fn_src, fn_dest):
-        with open(fn_src, 'rb') as f_src, open(fn_dest, 'wb') as f_dest:
-            shutil.copyfileobj(f_src, f_dest)
-
     def _check_file_equality(self, fn_1, fn_2, ext):
         # gz files contain mtime and filename in the header that
         # causes filecmp to return False even if contents are identical
@@ -205,13 +207,15 @@ class S3ToHiveTransferTest(unittest.TestCase):
                         msg="bz2 Compressed file not as expected")
 
     @unittest.skipIf(mock is None, 'mock package not present')
+    @unittest.skipIf(mock_s3 is None, 'moto package not present')
     @mock.patch('airflow.operators.s3_to_hive_operator.HiveCliHook')
-    @mock.patch('airflow.operators.s3_to_hive_operator.S3Hook')
-    def test_execute(self, mock_s3hook, mock_hiveclihook):
+    @mock_s3
+    def test_execute(self, mock_hiveclihook):
+        conn = boto3.client('s3')
+        conn.create_bucket(Bucket='bucket')
+
         # Testing txt, zip, bz2 files with and without header row
-        for test in product(['.txt', '.gz', '.bz2'], [True, False]):
-            ext = test[0]
-            has_header = test[1]
+        for (ext, has_header) in product(['.txt', '.gz', '.bz2'], [True, False]):
             self.kwargs['headers'] = has_header
             self.kwargs['check_headers'] = has_header
             logging.info("Testing {0} format {1} header".
@@ -219,15 +223,13 @@ class S3ToHiveTransferTest(unittest.TestCase):
                          ('with' if has_header else 'without'))
                         )
             self.kwargs['input_compressed'] = (False if ext == '.txt' else True)
-            self.kwargs['s3_key'] = self.s3_key + ext
+            self.kwargs['s3_key'] = 's3://bucket/' + self.s3_key + ext
             ip_fn = self._get_fn(ext, self.kwargs['headers'])
             op_fn = self._get_fn(ext, False)
-            # Mock s3 object returned by S3Hook
-            mock_s3_object = mock.Mock(key=self.kwargs['s3_key'])
-            mock_s3_object.get_contents_to_file.side_effect = \
-                lambda dest_file: \
-                self._cp_file_contents(ip_fn, dest_file.name)
-            mock_s3hook().get_key.return_value = mock_s3_object
+
+            # Upload the file into the Mocked S3 bucket
+            conn.upload_file(ip_fn, 'bucket', self.s3_key + ext)
+
             # file paramter to HiveCliHook.load_file is compared
             # against expected file oputput
             mock_hiveclihook().load_file.side_effect = \