[AIRFLOW-1795] Correctly call S3Hook after migration to boto3
In the migration of S3Hook to boto3, the connection ID parameter changed to `aws_conn_id`. This fixes the remaining uses of `s3_conn_id` in the code base and adds a note to UPDATING.md about the change. While correcting the tests for S3ToHiveTransfer I noticed that S3Hook.get_key was returning a dictionary, rather than the S3.Object mentioned in its docstring. The important thing that was missing was the ability to get the key name from the return value of a call to get_wildcard_key. Closes #2795 from ashb/AIRFLOW-1795-s3hook_boto3_fixes
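A hedged sketch of the behaviour the fix relies on (bucket, key and connection names are hypothetical; this is not code from the commit itself):

    from airflow.hooks.S3_hook import S3Hook

    hook = S3Hook(aws_conn_id='aws_default')   # was: S3Hook(s3_conn_id='s3_default')

    # get_key / get_wildcard_key now return a boto3 s3.Object, so the key name
    # is an attribute rather than a dict lookup.
    obj = hook.get_wildcard_key('incoming/*.csv', bucket_name='my-bucket')
    print(obj.key)                      # e.g. 'incoming/2017-11-01.csv'
    print(obj.get()['Body'].read())     # the object's contents as bytes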
This commit is contained in:
Parent: 54c03f3262
Commit: 98df0d6e3b
UPDATING.md | 13
@@ -6,12 +6,23 @@ assists people when migrating to a new version.

## Airflow 1.9

### SSH Hook updates, along with new SSH Operator & SFTP Operator

SSH Hook now uses the Paramiko library to create an ssh client connection, instead of the sub-process based ssh command execution used previously (<1.9.0), so this is backward incompatible.
  - update SSHHook constructor
  - use the SSHOperator class in place of SSHExecuteOperator, which is now removed. Refer to test_ssh_operator.py for usage info.
  - SFTPOperator is added to perform secure file transfer from server A to server B. Refer to test_sftp_operator.py for usage info.
  - No updates are required if you are using ftpHook; it will continue to work as is.

+### S3Hook switched to use Boto3
+
+The airflow.hooks.S3_hook.S3Hook has been switched to use boto3 instead of the older boto (a.k.a. boto2). This results in a few backwards incompatible changes to the S3Hook class:
+
+- the constructor no longer accepts `s3_conn_id`; the parameter is now called `aws_conn_id`
+- the default connection is now "aws_default" instead of "s3_default"
+- the return type of objects returned by `get_bucket` is now a boto3.s3.Bucket
+- the return type of `get_key` and `get_wildcard_key` is now a boto3.s3.Object
+
+If you are using any of these in your DAGs and specify a connection ID, you will need to update the parameter name for the connection to "aws_conn_id": S3ToHiveTransfer, S3PrefixSensor, S3KeySensor, RedshiftToS3Transfer.
+
### Logging update

The logging structure of Airflow has been rewritten to make configuration easier and the logging system more transparent.
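A minimal sketch of the new return types, assuming a configured "aws_default" connection and a hypothetical bucket:

    from airflow.hooks.S3_hook import S3Hook

    hook = S3Hook(aws_conn_id='aws_default')

    bucket = hook.get_bucket('my-bucket')      # boto3 s3.Bucket
    for summary in bucket.objects.limit(5):    # boto3-style iteration replaces boto2's bucket.list()
        print(summary.key)

    obj = hook.get_key('path/to/file.csv', bucket_name='my-bucket')  # boto3 s3.Object
    print(obj.content_length, obj.last_modified)   # metadata is loaded lazily by boto3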
@@ -123,7 +123,7 @@ class S3Hook(AwsHook):

    def get_key(self, key, bucket_name=None):
        """
-        Returns a boto3.S3.Key object
+        Returns a boto3.s3.Object

        :param key: the path to the key
        :type key: str
@@ -132,8 +132,10 @@ class S3Hook(AwsHook):
        """
        if not bucket_name:
            (bucket_name, key) = self.parse_s3_url(key)
-        return self.get_conn().get_object(Bucket=bucket_name, Key=key)
+
+        obj = self.get_resource_type('s3').Object(bucket_name, key)
+        obj.load()
+        return obj

    def read_key(self, key, bucket_name=None):
        """
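The resource-based lookup above is roughly equivalent to the plain boto3 calls below; obj.load() issues a HEAD request, so a missing key now surfaces as a botocore ClientError instead of an empty response (a sketch with hypothetical names, not the hook's exact error handling):

    import boto3
    from botocore.exceptions import ClientError

    s3 = boto3.resource('s3')
    obj = s3.Object('my-bucket', 'my-key')
    try:
        obj.load()   # HEAD request; populates metadata such as content_length
    except ClientError as e:
        if e.response['Error']['Code'] == '404':
            print('key does not exist')
        else:
            raise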
@@ -144,9 +146,9 @@ class S3Hook(AwsHook):
        :param bucket_name: Name of the bucket in which the file is stored
        :type bucket_name: str
        """

        obj = self.get_key(key, bucket_name)
-        return obj['Body'].read().decode('utf-8')
+        return obj.get()['Body'].read().decode('utf-8')

    def check_for_wildcard_key(self,
                               wildcard_key, bucket_name=None, delimiter=''):
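read_key now goes through the object's get() call; the underlying boto3 pattern looks like this (hypothetical names, assuming UTF-8 content):

    import boto3

    obj = boto3.resource('s3').Object('my-bucket', 'my-key')
    body = obj.get()['Body'].read()   # bytes
    text = body.decode('utf-8')       # what S3Hook.read_key returns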
@@ -159,7 +161,7 @@ class S3Hook(AwsHook):

    def get_wildcard_key(self, wildcard_key, bucket_name=None, delimiter=''):
        """
-        Returns a boto3.s3.Key object matching the regular expression
+        Returns a boto3.s3.Object object matching the regular expression

        :param regex_key: the path to the key
        :type regex_key: str
@@ -30,8 +30,8 @@ class RedshiftToS3Transfer(BaseOperator):
    :type s3_key: string
    :param redshift_conn_id: reference to a specific redshift database
    :type redshift_conn_id: string
-    :param s3_conn_id: reference to a specific S3 connection
-    :type s3_conn_id: string
+    :param aws_conn_id: reference to a specific S3 connection
+    :type aws_conn_id: string
    :param options: reference to a list of UNLOAD options
    :type options: list
    """
@@ -48,7 +48,7 @@ class RedshiftToS3Transfer(BaseOperator):
            s3_bucket,
            s3_key,
            redshift_conn_id='redshift_default',
-            s3_conn_id='s3_default',
+            aws_conn_id='aws_default',
            unload_options=tuple(),
            autocommit=False,
            parameters=None,
@@ -59,14 +59,14 @@ class RedshiftToS3Transfer(BaseOperator):
        self.s3_bucket = s3_bucket
        self.s3_key = s3_key
        self.redshift_conn_id = redshift_conn_id
-        self.s3_conn_id = s3_conn_id
+        self.aws_conn_id = aws_conn_id
        self.unload_options = unload_options
        self.autocommit = autocommit
        self.parameters = parameters

    def execute(self, context):
        self.hook = PostgresHook(postgres_conn_id=self.redshift_conn_id)
-        self.s3 = S3Hook(s3_conn_id=self.s3_conn_id)
+        self.s3 = S3Hook(aws_conn_id=self.aws_conn_id)
        a_key, s_key = self.s3.get_credentials()
        unload_options = '\n\t\t\t'.join(self.unload_options)
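Existing RedshiftToS3Transfer tasks only need the keyword swapped; a sketch with hypothetical task, table and connection names:

    from airflow.operators.redshift_to_s3_operator import RedshiftToS3Transfer

    unload = RedshiftToS3Transfer(
        task_id='unload_table',
        schema='public',
        table='my_table',
        s3_bucket='my-bucket',
        s3_key='exports/my_table',
        redshift_conn_id='redshift_default',
        aws_conn_id='aws_default',        # was: s3_conn_id='s3_default'
    )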
@@ -37,12 +37,12 @@ class S3FileTransformOperator(BaseOperator):

    :param source_s3_key: The key to be retrieved from S3
    :type source_s3_key: str
-    :param source_s3_conn_id: source s3 connection
-    :type source_s3_conn_id: str
+    :param source_aws_conn_id: source s3 connection
+    :type source_aws_conn_id: str
    :param dest_s3_key: The key to be written from S3
    :type dest_s3_key: str
-    :param dest_s3_conn_id: destination s3 connection
-    :type dest_s3_conn_id: str
+    :param dest_aws_conn_id: destination s3 connection
+    :type dest_aws_conn_id: str
    :param replace: Replace dest S3 key if it already exists
    :type replace: bool
    :param transform_script: location of the executable transformation script
@@ -59,21 +59,21 @@ class S3FileTransformOperator(BaseOperator):
            source_s3_key,
            dest_s3_key,
            transform_script,
-            source_s3_conn_id='s3_default',
-            dest_s3_conn_id='s3_default',
+            source_aws_conn_id='aws_default',
+            dest_aws_conn_id='aws_default',
            replace=False,
            *args, **kwargs):
        super(S3FileTransformOperator, self).__init__(*args, **kwargs)
        self.source_s3_key = source_s3_key
-        self.source_s3_conn_id = source_s3_conn_id
+        self.source_aws_conn_id = source_aws_conn_id
        self.dest_s3_key = dest_s3_key
-        self.dest_s3_conn_id = dest_s3_conn_id
+        self.dest_aws_conn_id = dest_aws_conn_id
        self.replace = replace
        self.transform_script = transform_script

    def execute(self, context):
-        source_s3 = S3Hook(s3_conn_id=self.source_s3_conn_id)
-        dest_s3 = S3Hook(s3_conn_id=self.dest_s3_conn_id)
+        source_s3 = S3Hook(aws_conn_id=self.source_aws_conn_id)
+        dest_s3 = S3Hook(aws_conn_id=self.dest_aws_conn_id)
        self.log.info("Downloading source S3 file %s", self.source_s3_key)
        if not source_s3.check_for_key(self.source_s3_key):
            raise AirflowException("The source key {0} does not exist".format(self.source_s3_key))
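For S3FileTransformOperator both the source and destination keywords are renamed; a sketch with hypothetical keys and script path:

    from airflow.operators.s3_file_transform_operator import S3FileTransformOperator

    transform = S3FileTransformOperator(
        task_id='transform_file',
        source_s3_key='s3://my-bucket/in/data.csv',
        dest_s3_key='s3://my-bucket/out/data.csv',
        transform_script='/usr/local/bin/my_transform.py',
        source_aws_conn_id='aws_default',   # was: source_s3_conn_id='s3_default'
        dest_aws_conn_id='aws_default',     # was: dest_s3_conn_id='s3_default'
        replace=True,
    )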
@@ -71,8 +71,8 @@ class S3ToHiveTransfer(BaseOperator):
    :type wildcard_match: bool
    :param delimiter: field delimiter in the file
    :type delimiter: str
-    :param s3_conn_id: source s3 connection
-    :type s3_conn_id: str
+    :param aws_conn_id: source s3 connection
+    :type aws_conn_id: str
    :param hive_cli_conn_id: destination hive connection
    :type hive_cli_conn_id: str
    :param input_compressed: Boolean to determine if file decompression is
@@ -99,7 +99,7 @@ class S3ToHiveTransfer(BaseOperator):
            headers=False,
            check_headers=False,
            wildcard_match=False,
-            s3_conn_id='s3_default',
+            aws_conn_id='aws_default',
            hive_cli_conn_id='hive_cli_default',
            input_compressed=False,
            tblproperties=None,
@@ -116,7 +116,7 @@ class S3ToHiveTransfer(BaseOperator):
        self.check_headers = check_headers
        self.wildcard_match = wildcard_match
        self.hive_cli_conn_id = hive_cli_conn_id
-        self.s3_conn_id = s3_conn_id
+        self.aws_conn_id = aws_conn_id
        self.input_compressed = input_compressed
        self.tblproperties = tblproperties
@@ -127,7 +127,7 @@ class S3ToHiveTransfer(BaseOperator):

    def execute(self, context):
        # Downloading file from S3
-        self.s3 = S3Hook(s3_conn_id=self.s3_conn_id)
+        self.s3 = S3Hook(aws_conn_id=self.aws_conn_id)
        self.hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
        self.log.info("Downloading S3 file")
@@ -143,14 +143,13 @@ class S3ToHiveTransfer(BaseOperator):
        s3_key_object = self.s3.get_key(self.s3_key)
        root, file_ext = os.path.splitext(s3_key_object.key)
        with TemporaryDirectory(prefix='tmps32hive_') as tmp_dir,\
-                NamedTemporaryFile(mode="w",
+                NamedTemporaryFile(mode="wb",
                                   dir=tmp_dir,
                                   suffix=file_ext) as f:
            self.log.info("Dumping S3 key {0} contents to local file {1}"
                          .format(s3_key_object.key, f.name))
-            s3_key_object.get_contents_to_file(f)
+            s3_key_object.download_fileobj(f)
            f.flush()
-            self.s3.connection.close()
            if not self.headers:
                self.log.info("Loading file %s into Hive", f.name)
                self.hive.load_file(
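The switch to download_fileobj is why the temporary file is now opened in binary mode; the underlying boto3 pattern is roughly this (hypothetical bucket and key):

    import boto3
    from tempfile import NamedTemporaryFile

    obj = boto3.resource('s3').Object('my-bucket', 'exports/data.csv.gz')
    with NamedTemporaryFile(mode='wb', suffix='.gz') as f:
        obj.download_fileobj(f)   # writes raw bytes, hence mode='wb'
        f.flush()
        # hand f.name to downstream code (e.g. HiveCliHook.load_file) before the file is removed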
@@ -501,8 +501,8 @@ class S3KeySensor(BaseSensorOperator):
    :param wildcard_match: whether the bucket_key should be interpreted as a
        Unix wildcard pattern
    :type wildcard_match: bool
-    :param s3_conn_id: a reference to the s3 connection
-    :type s3_conn_id: str
+    :param aws_conn_id: a reference to the s3 connection
+    :type aws_conn_id: str
    """
    template_fields = ('bucket_key', 'bucket_name')
@@ -511,7 +511,7 @@ class S3KeySensor(BaseSensorOperator):
            self, bucket_key,
            bucket_name=None,
            wildcard_match=False,
-            s3_conn_id='s3_default',
+            aws_conn_id='aws_default',
            *args, **kwargs):
        super(S3KeySensor, self).__init__(*args, **kwargs)
        # Parse
@@ -528,11 +528,11 @@ class S3KeySensor(BaseSensorOperator):
        self.bucket_name = bucket_name
        self.bucket_key = bucket_key
        self.wildcard_match = wildcard_match
-        self.s3_conn_id = s3_conn_id
+        self.aws_conn_id = aws_conn_id

    def poke(self, context):
        from airflow.hooks.S3_hook import S3Hook
-        hook = S3Hook(s3_conn_id=self.s3_conn_id)
+        hook = S3Hook(aws_conn_id=self.aws_conn_id)
        full_url = "s3://" + self.bucket_name + "/" + self.bucket_key
        self.log.info('Poking for key : {full_url}'.format(**locals()))
        if self.wildcard_match:
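An S3KeySensor in a DAG only needs the renamed keyword; a sketch with a hypothetical key:

    from airflow.operators.sensors import S3KeySensor

    wait_for_file = S3KeySensor(
        task_id='wait_for_file',
        bucket_key='s3://my-bucket/incoming/*.csv',
        wildcard_match=True,
        aws_conn_id='aws_default',   # was: s3_conn_id='s3_default'
    )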
@@ -565,7 +565,7 @@ class S3PrefixSensor(BaseSensorOperator):
    def __init__(
            self, bucket_name,
            prefix, delimiter='/',
-            s3_conn_id='s3_default',
+            aws_conn_id='aws_default',
            *args, **kwargs):
        super(S3PrefixSensor, self).__init__(*args, **kwargs)
        # Parse
@@ -573,13 +573,13 @@ class S3PrefixSensor(BaseSensorOperator):
        self.prefix = prefix
        self.delimiter = delimiter
        self.full_url = "s3://" + bucket_name + '/' + prefix
-        self.s3_conn_id = s3_conn_id
+        self.aws_conn_id = aws_conn_id

    def poke(self, context):
        self.log.info('Poking for prefix : {self.prefix}\n'
                      'in bucket s3://{self.bucket_name}'.format(**locals()))
        from airflow.hooks.S3_hook import S3Hook
-        hook = S3Hook(s3_conn_id=self.s3_conn_id)
+        hook = S3Hook(aws_conn_id=self.aws_conn_id)
        return hook.check_for_prefix(
            prefix=self.prefix,
            delimiter=self.delimiter,
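Likewise for S3PrefixSensor (hypothetical bucket and prefix):

    from airflow.operators.sensors import S3PrefixSensor

    wait_for_prefix = S3PrefixSensor(
        task_id='wait_for_prefix',
        bucket_name='my-bucket',
        prefix='incoming/2017-11-01/',
        aws_conn_id='aws_default',   # was: s3_conn_id='s3_default'
    )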
@@ -32,6 +32,12 @@ import shutil
import filecmp
import errno

+try:
+    import boto3
+    from moto import mock_s3
+except ImportError:
+    mock_s3 = None
+

class S3ToHiveTransferTest(unittest.TestCase):
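moto's mock_s3 decorator intercepts boto3 calls in-process, which is why it can be imported guardedly and the test skipped when it is absent; a minimal sketch of the pattern:

    import boto3
    from moto import mock_s3

    @mock_s3
    def roundtrip():
        conn = boto3.client('s3')
        conn.create_bucket(Bucket='bucket')   # fake, in-memory bucket
        conn.put_object(Bucket='bucket', Key='hello.txt', Body=b'hi')
        return conn.get_object(Bucket='bucket', Key='hello.txt')['Body'].read()

    assert roundtrip() == b'hi'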
@@ -128,10 +134,6 @@ class S3ToHiveTransferTest(unittest.TestCase):
        key = ext + "_" + ('h' if header else 'nh')
        return key

-    def _cp_file_contents(self, fn_src, fn_dest):
-        with open(fn_src, 'rb') as f_src, open(fn_dest, 'wb') as f_dest:
-            shutil.copyfileobj(f_src, f_dest)
-
    def _check_file_equality(self, fn_1, fn_2, ext):
        # gz files contain mtime and filename in the header that
        # causes filecmp to return False even if contents are identical
@@ -205,13 +207,15 @@ class S3ToHiveTransferTest(unittest.TestCase):
                        msg="bz2 Compressed file not as expected")

    @unittest.skipIf(mock is None, 'mock package not present')
+    @unittest.skipIf(mock_s3 is None, 'moto package not present')
    @mock.patch('airflow.operators.s3_to_hive_operator.HiveCliHook')
-    @mock.patch('airflow.operators.s3_to_hive_operator.S3Hook')
-    def test_execute(self, mock_s3hook, mock_hiveclihook):
+    @mock_s3
+    def test_execute(self, mock_hiveclihook):
+        conn = boto3.client('s3')
+        conn.create_bucket(Bucket='bucket')
+
        # Testing txt, zip, bz2 files with and without header row
-        for test in product(['.txt', '.gz', '.bz2'], [True, False]):
-            ext = test[0]
-            has_header = test[1]
+        for (ext, has_header) in product(['.txt', '.gz', '.bz2'], [True, False]):
            self.kwargs['headers'] = has_header
            self.kwargs['check_headers'] = has_header
            logging.info("Testing {0} format {1} header".
@@ -219,15 +223,13 @@ class S3ToHiveTransferTest(unittest.TestCase):
                         ('with' if has_header else 'without'))
                         )
            self.kwargs['input_compressed'] = (False if ext == '.txt' else True)
-            self.kwargs['s3_key'] = self.s3_key + ext
+            self.kwargs['s3_key'] = 's3://bucket/' + self.s3_key + ext
            ip_fn = self._get_fn(ext, self.kwargs['headers'])
            op_fn = self._get_fn(ext, False)
-            # Mock s3 object returned by S3Hook
-            mock_s3_object = mock.Mock(key=self.kwargs['s3_key'])
-            mock_s3_object.get_contents_to_file.side_effect = \
-                lambda dest_file: \
-                self._cp_file_contents(ip_fn, dest_file.name)
-            mock_s3hook().get_key.return_value = mock_s3_object

+            # Upload the file into the mocked S3 bucket
+            conn.upload_file(ip_fn, 'bucket', self.s3_key + ext)

            # file parameter to HiveCliHook.load_file is compared
            # against expected file output
            mock_hiveclihook().load_file.side_effect = \
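With the bucket created up front, each iteration seeds the fake bucket and lets the real S3Hook resolve the s3:// URL; the essential pattern in isolation (hypothetical file and key names):

    import boto3
    from moto import mock_s3

    @mock_s3
    def upload_and_stat(local_path, key):
        conn = boto3.client('s3')
        conn.create_bucket(Bucket='bucket')
        conn.upload_file(local_path, 'bucket', key)       # replaces the hand-rolled get_key mock

        obj = boto3.resource('s3').Object('bucket', key)  # what S3Hook.get_key('s3://bucket/' + key) resolves to
        obj.load()
        return obj.content_length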