Add type annotations for AWS operators and hooks (#11434)

Co-authored-by: Tomek Urbaszek <turbaszek@gmail.com>
Katsunori Kanda 2020-10-17 02:54:08 +09:00 committed by GitHub
Parent 3c10ca6504
Commit 0823d46a7f
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 124 additions and 108 deletions

View file

@@ -29,6 +29,7 @@ import logging
from typing import Any, Dict, Optional, Tuple, Union
import boto3
from botocore.credentials import ReadOnlyCredentials
from botocore.config import Config
from cached_property import cached_property
@@ -393,7 +394,7 @@ class AwsBaseHook(BaseHook):
session, _ = self._get_credentials(region_name)
return session
def get_credentials(self, region_name: Optional[str] = None) -> Tuple[Optional[str], Optional[str]]:
def get_credentials(self, region_name: Optional[str] = None) -> ReadOnlyCredentials:
"""
Get the underlying `botocore.Credentials` object.
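A minimal usage sketch of the re-annotated method (not part of the diff; the connection id and client_type are assumptions). botocore's ReadOnlyCredentials is a named tuple, so with the new return annotation the fields can be read by name instead of by tuple position:

from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook

hook = AwsBaseHook(aws_conn_id="aws_default", client_type="s3")   # hypothetical connection
creds = hook.get_credentials()   # ReadOnlyCredentials(access_key=..., secret_key=..., token=...)
print(creds.access_key, creds.token is not None)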

View file

@@ -25,7 +25,7 @@ import json
from copy import copy
from os.path import getsize
from tempfile import NamedTemporaryFile
from typing import Any, Callable, Dict, Optional
from typing import Any, Callable, Dict, Optional, IO
from uuid import uuid4
from airflow.models import BaseOperator
@@ -34,11 +34,11 @@ from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.utils.decorators import apply_defaults
def _convert_item_to_json_bytes(item):
def _convert_item_to_json_bytes(item: Dict[str, Any]) -> bytes:
return (json.dumps(item) + '\n').encode('utf-8')
def _upload_file_to_s3(file_obj, bucket_name, s3_key_prefix):
def _upload_file_to_s3(file_obj: IO, bucket_name: str, s3_key_prefix: str) -> None:
s3_client = S3Hook().get_conn()
file_obj.seek(0)
s3_client.upload_file(
@@ -102,7 +102,7 @@ class DynamoDBToS3Operator(BaseOperator):
s3_key_prefix: str = '',
process_func: Callable[[Dict[str, Any]], bytes] = _convert_item_to_json_bytes,
**kwargs,
):
) -> None:
super().__init__(**kwargs)
self.file_size = file_size
self.process_func = process_func
@@ -111,7 +111,7 @@ class DynamoDBToS3Operator(BaseOperator):
self.s3_bucket_name = s3_bucket_name
self.s3_key_prefix = s3_key_prefix
def execute(self, context):
def execute(self, context) -> None:
table = AwsDynamoDBHook().get_conn().Table(self.dynamodb_table_name)
scan_kwargs = copy(self.dynamodb_scan_kwargs) if self.dynamodb_scan_kwargs else {}
err = None
@@ -126,7 +126,7 @@ class DynamoDBToS3Operator(BaseOperator):
_upload_file_to_s3(f, self.s3_bucket_name, self.s3_key_prefix)
f.close()
def _scan_dynamodb_and_upload_to_s3(self, temp_file, scan_kwargs, table):
def _scan_dynamodb_and_upload_to_s3(self, temp_file: IO, scan_kwargs: dict, table: Any) -> IO:
while True:
response = table.scan(**scan_kwargs)
items = response['Items']
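A standalone sketch (not from the commit) of what the newly annotated default process_func does: it turns one DynamoDB item dict into newline-terminated UTF-8 JSON bytes, so the temporary file the operator uploads ends up in JSON Lines form. The item below is made up:

item = {"id": "42", "name": "alice"}   # hypothetical DynamoDB item
assert _convert_item_to_json_bytes(item) == b'{"id": "42", "name": "alice"}\n'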

View file

@@ -19,7 +19,7 @@
This module contains Google Cloud Storage to S3 operator.
"""
import warnings
from typing import Iterable, Optional, Sequence, Union, Dict
from typing import Iterable, Optional, Sequence, Union, Dict, List, cast
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
@@ -100,21 +100,21 @@ class GCSToS3Operator(BaseOperator):
def __init__(
self,
*, # pylint: disable=too-many-arguments
bucket,
prefix=None,
delimiter=None,
gcp_conn_id='google_cloud_default',
google_cloud_storage_conn_id=None,
delegate_to=None,
dest_aws_conn_id=None,
dest_s3_key=None,
dest_verify=None,
replace=False,
bucket: str,
prefix: Optional[str] = None,
delimiter: Optional[str] = None,
gcp_conn_id: str = 'google_cloud_default',
google_cloud_storage_conn_id: Optional[str] = None,
delegate_to: Optional[str] = None,
dest_aws_conn_id: str = 'aws_default',
dest_s3_key: str,
dest_verify: Optional[Union[str, bool]] = None,
replace: bool = False,
google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
dest_s3_extra_args: Optional[Dict] = None,
s3_acl_policy: Optional[str] = None,
**kwargs,
):
) -> None:
super().__init__(**kwargs)
if google_cloud_storage_conn_id:
@@ -139,7 +139,7 @@ class GCSToS3Operator(BaseOperator):
self.dest_s3_extra_args = dest_s3_extra_args or {}
self.s3_acl_policy = s3_acl_policy
def execute(self, context):
def execute(self, context) -> List[str]:
# list all files in a Google Cloud Storage bucket
hook = GCSHook(
google_cloud_storage_conn_id=self.gcp_conn_id,
@@ -183,7 +183,7 @@ class GCSToS3Operator(BaseOperator):
self.log.info("Saving file to %s", dest_key)
s3_hook.load_bytes(
file_bytes, key=dest_key, replace=self.replace, acl_policy=self.s3_acl_policy
cast(bytes, file_bytes), key=dest_key, replace=self.replace, acl_policy=self.s3_acl_policy
)
self.log.info("All done, uploaded %d files to S3", len(files))

View file

@@ -72,17 +72,17 @@ class GlacierToGCSOperator(BaseOperator):
def __init__(
self,
*,
aws_conn_id="aws_default",
gcp_conn_id="google_cloud_default",
aws_conn_id: str = "aws_default",
gcp_conn_id: str = "google_cloud_default",
vault_name: str,
bucket_name: str,
object_name: str,
gzip: bool,
chunk_size=1024,
delegate_to=None,
chunk_size: int = 1024,
delegate_to: Optional[str] = None,
google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
**kwargs,
):
) -> None:
super().__init__(**kwargs)
self.aws_conn_id = aws_conn_id
self.gcp_conn_id = gcp_conn_id
@@ -94,7 +94,7 @@ class GlacierToGCSOperator(BaseOperator):
self.delegate_to = delegate_to
self.impersonation_chain = google_impersonation_chain
def execute(self, context):
def execute(self, context) -> str:
glacier_hook = GlacierHook(aws_conn_id=self.aws_conn_id)
gcs_hook = GCSHook(
gcp_conn_id=self.gcp_conn_id,
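Similarly, a hedged construction sketch consistent with the annotations above; the vault, bucket and object names are placeholders, and task_id comes from BaseOperator:

backup_to_gcs = GlacierToGCSOperator(
    task_id="glacier_to_gcs",              # hypothetical task id
    vault_name="example-vault",
    bucket_name="example-gcs-bucket",
    object_name="backups/archive.tar.gz",
    gzip=False,
    chunk_size=1024,                       # stays an int, per the new annotation
    aws_conn_id="aws_default",
    gcp_conn_id="google_cloud_default",
)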

View file

@@ -23,8 +23,8 @@ import json
import sys
from typing import Optional, Sequence, Union
from airflow.models import BaseOperator
from airflow.models.xcom import MAX_XCOM_SIZE
from airflow.models import BaseOperator, TaskInstance
from airflow.models.xcom import MAX_XCOM_SIZE, XCOM_RETURN_KEY
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.providers.google.common.hooks.discovery_api import GoogleDiscoveryApiHook
from airflow.utils.decorators import apply_defaults
@@ -98,20 +98,20 @@ class GoogleApiToS3Operator(BaseOperator):
def __init__(
self,
*,
google_api_service_name,
google_api_service_version,
google_api_endpoint_path,
google_api_endpoint_params,
s3_destination_key,
google_api_response_via_xcom=None,
google_api_endpoint_params_via_xcom=None,
google_api_endpoint_params_via_xcom_task_ids=None,
google_api_pagination=False,
google_api_num_retries=0,
s3_overwrite=False,
gcp_conn_id='google_cloud_default',
delegate_to=None,
aws_conn_id='aws_default',
google_api_service_name: str,
google_api_service_version: str,
google_api_endpoint_path: str,
google_api_endpoint_params: dict,
s3_destination_key: str,
google_api_response_via_xcom: Optional[str] = None,
google_api_endpoint_params_via_xcom: Optional[str] = None,
google_api_endpoint_params_via_xcom_task_ids: Optional[str] = None,
google_api_pagination: bool = False,
google_api_num_retries: int = 0,
s3_overwrite: bool = False,
gcp_conn_id: str = 'google_cloud_default',
delegate_to: Optional[str] = None,
aws_conn_id: str = 'aws_default',
google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
**kwargs,
):
@@ -132,7 +132,7 @@ class GoogleApiToS3Operator(BaseOperator):
self.aws_conn_id = aws_conn_id
self.google_impersonation_chain = google_impersonation_chain
def execute(self, context):
def execute(self, context) -> None:
"""
Transfers Google APIs json data to S3.
@@ -151,7 +151,7 @@ class GoogleApiToS3Operator(BaseOperator):
if self.google_api_response_via_xcom:
self._expose_google_api_response_via_xcom(context['task_instance'], data)
def _retrieve_data_from_google_api(self):
def _retrieve_data_from_google_api(self) -> dict:
google_discovery_api_hook = GoogleDiscoveryApiHook(
gcp_conn_id=self.gcp_conn_id,
delegate_to=self.delegate_to,
@@ -167,21 +167,21 @@ class GoogleApiToS3Operator(BaseOperator):
)
return google_api_response
def _load_data_to_s3(self, data):
def _load_data_to_s3(self, data: dict) -> None:
s3_hook = S3Hook(aws_conn_id=self.aws_conn_id)
s3_hook.load_string(
string_data=json.dumps(data), key=self.s3_destination_key, replace=self.s3_overwrite
)
def _update_google_api_endpoint_params_via_xcom(self, task_instance):
def _update_google_api_endpoint_params_via_xcom(self, task_instance: TaskInstance) -> None:
google_api_endpoint_params = task_instance.xcom_pull(
task_ids=self.google_api_endpoint_params_via_xcom_task_ids,
key=self.google_api_endpoint_params_via_xcom,
)
self.google_api_endpoint_params.update(google_api_endpoint_params)
def _expose_google_api_response_via_xcom(self, task_instance, data):
def _expose_google_api_response_via_xcom(self, task_instance: TaskInstance, data: dict) -> None:
if sys.getsizeof(data) < MAX_XCOM_SIZE:
task_instance.xcom_push(key=self.google_api_response_via_xcom, value=data)
task_instance.xcom_push(key=self.google_api_response_via_xcom or XCOM_RETURN_KEY, value=data)
else:
raise RuntimeError('The size of the downloaded data is too large to push to XCom!')
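The `or XCOM_RETURN_KEY` fallback is a typing aid: google_api_response_via_xcom is now annotated as Optional[str], while xcom_push expects a non-None key, and the caller only reaches this method when the attribute is set. A tiny standalone illustration (the variable name is invented):

from airflow.models.xcom import XCOM_RETURN_KEY

response_key = None                      # Optional[str], mirroring the new annotation
key = response_key or XCOM_RETURN_KEY    # falls back to Airflow's default key, 'return_value'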

View file

@@ -21,6 +21,7 @@ This module contains operator to move data from Hive to DynamoDB.
"""
import json
from typing import Optional, Callable
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.dynamodb import AwsDynamoDBHook
@@ -64,18 +65,18 @@ class HiveToDynamoDBOperator(BaseOperator):
def __init__( # pylint: disable=too-many-arguments
self,
*,
sql,
table_name,
table_keys,
pre_process=None,
pre_process_args=None,
pre_process_kwargs=None,
region_name=None,
schema='default',
hiveserver2_conn_id='hiveserver2_default',
aws_conn_id='aws_default',
sql: str,
table_name: str,
table_keys: list,
pre_process: Optional[Callable] = None,
pre_process_args: Optional[list] = None,
pre_process_kwargs: Optional[list] = None,
region_name: Optional[str] = None,
schema: str = 'default',
hiveserver2_conn_id: str = 'hiveserver2_default',
aws_conn_id: str = 'aws_default',
**kwargs,
):
) -> None:
super().__init__(**kwargs)
self.sql = sql
self.table_name = table_name

View file

@@ -57,16 +57,16 @@ class ImapAttachmentToS3Operator(BaseOperator):
def __init__(
self,
*,
imap_attachment_name,
s3_key,
imap_check_regex=False,
imap_mail_folder='INBOX',
imap_mail_filter='All',
s3_overwrite=False,
imap_conn_id='imap_default',
s3_conn_id='aws_default',
imap_attachment_name: str,
s3_key: str,
imap_check_regex: bool = False,
imap_mail_folder: str = 'INBOX',
imap_mail_filter: str = 'All',
s3_overwrite: bool = False,
imap_conn_id: str = 'imap_default',
s3_conn_id: str = 'aws_default',
**kwargs,
):
) -> None:
super().__init__(**kwargs)
self.imap_attachment_name = imap_attachment_name
self.s3_key = s3_key
@@ -77,7 +77,7 @@ class ImapAttachmentToS3Operator(BaseOperator):
self.imap_conn_id = imap_conn_id
self.s3_conn_id = s3_conn_id
def execute(self, context):
def execute(self, context) -> None:
"""
This function executes the transfer from the email server (via imap) into s3.

View file

@@ -16,6 +16,7 @@
# specific language governing permissions and limitations
# under the License.
import json
from typing import Optional, Any, Iterable, Union, cast
from bson import json_util
@@ -44,16 +45,16 @@ class MongoToS3Operator(BaseOperator):
def __init__(
self,
*,
mongo_conn_id,
s3_conn_id,
mongo_collection,
mongo_query,
s3_bucket,
s3_key,
mongo_db=None,
replace=False,
mongo_conn_id: str,
s3_conn_id: str,
mongo_collection: str,
mongo_query: Union[list, dict],
s3_bucket: str,
s3_key: str,
mongo_db: Optional[str] = None,
replace: bool = False,
**kwargs,
):
) -> None:
super().__init__(**kwargs)
# Conn Ids
self.mongo_conn_id = mongo_conn_id
@@ -70,7 +71,7 @@ class MongoToS3Operator(BaseOperator):
self.s3_key = s3_key
self.replace = replace
def execute(self, context):
def execute(self, context) -> bool:
"""
Executed by task_instance at runtime
"""
@@ -80,13 +81,15 @@ class MongoToS3Operator(BaseOperator):
if self.is_pipeline:
results = MongoHook(self.mongo_conn_id).aggregate(
mongo_collection=self.mongo_collection,
aggregate_query=self.mongo_query,
aggregate_query=cast(list, self.mongo_query),
mongo_db=self.mongo_db,
)
else:
results = MongoHook(self.mongo_conn_id).find(
mongo_collection=self.mongo_collection, query=self.mongo_query, mongo_db=self.mongo_db
mongo_collection=self.mongo_collection,
query=cast(dict, self.mongo_query),
mongo_db=self.mongo_db,
)
# Performs transform then stringifies the docs results into json format
@@ -100,7 +103,7 @@ class MongoToS3Operator(BaseOperator):
return True
@staticmethod
def _stringify(iterable, joinable='\n'):
def _stringify(iterable: Iterable, joinable: str = '\n') -> str:
"""
Takes an iterable (pymongo Cursor or Array) containing dictionaries and
returns a stringified version using python join
@@ -108,7 +111,7 @@ class MongoToS3Operator(BaseOperator):
return joinable.join([json.dumps(doc, default=json_util.default) for doc in iterable])
@staticmethod
def transform(docs):
def transform(docs: Any) -> Any:
"""
Processes pyMongo cursor and returns an iterable with each element being
a JSON serializable dictionary
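The two cast(...) calls above are purely static-typing aids: mongo_query is annotated as Union[list, dict], and cast tells mypy which shape each branch receives without converting anything at runtime. A minimal illustration outside the operator (the query is made up):

from typing import Union, cast

query: Union[list, dict] = [{"$match": {"status": "active"}}]   # hypothetical aggregation pipeline
pipeline = cast(list, query)   # no runtime effect; only narrows the static type
assert pipeline is query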

View file

@@ -101,7 +101,7 @@ class MySQLToS3Operator(BaseOperator):
if "header" not in self.pd_csv_kwargs:
self.pd_csv_kwargs["header"] = header
def _fix_int_dtypes(self, df):
def _fix_int_dtypes(self, df: pd.DataFrame) -> None:
"""
Mutate DataFrame to set dtypes for int columns containing NaN values.
"""
@@ -113,7 +113,7 @@ class MySQLToS3Operator(BaseOperator):
# set to dtype that retains integers and supports NaNs
df[col] = np.where(df[col].isnull(), None, df[col]).astype(pd.Int64Dtype)
def execute(self, context):
def execute(self, context) -> None:
mysql_hook = MySqlHook(mysql_conn_id=self.mysql_conn_id)
s3_conn = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
data_df = mysql_hook.get_pandas_df(self.query)
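Context for the newly annotated _fix_int_dtypes helper: pandas' nullable Int64 extension dtype is what lets an integer column keep missing values without being coerced to float. A standalone sketch, not taken from the operator:

import pandas as pd

s = pd.Series([1.0, None, 3.0])   # float64 because of the missing value
s = s.astype("Int64")             # nullable integer dtype; the gap becomes <NA>
print(s.dtype)                    # Int64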

View file

@@ -104,7 +104,7 @@ class RedshiftToS3Operator(BaseOperator):
'HEADER',
]
def execute(self, context):
def execute(self, context) -> None:
postgres_hook = PostgresHook(postgres_conn_id=self.redshift_conn_id)
s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)

View file

@@ -87,13 +87,11 @@ class S3ToRedshiftOperator(BaseOperator):
self.verify = verify
self.copy_options = copy_options or []
self.autocommit = autocommit
self._s3_hook = None
self._postgres_hook = None
def execute(self, context):
self._postgres_hook = PostgresHook(postgres_conn_id=self.redshift_conn_id)
self._s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
credentials = self._s3_hook.get_credentials()
def execute(self, context) -> None:
postgres_hook = PostgresHook(postgres_conn_id=self.redshift_conn_id)
s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
credentials = s3_hook.get_credentials()
copy_options = '\n\t\t\t'.join(self.copy_options)
copy_query = """
@@ -113,5 +111,5 @@ class S3ToRedshiftOperator(BaseOperator):
)
self.log.info('Executing COPY command...')
self._postgres_hook.run(copy_query, self.autocommit)
postgres_hook.run(copy_query, self.autocommit)
self.log.info("COPY command complete...")

View file

@@ -50,8 +50,15 @@ class S3ToSFTPOperator(BaseOperator):
@apply_defaults
def __init__(
self, *, s3_bucket, s3_key, sftp_path, sftp_conn_id='ssh_default', s3_conn_id='aws_default', **kwargs
):
self,
*,
s3_bucket: str,
s3_key: str,
sftp_path: str,
sftp_conn_id: str = 'ssh_default',
s3_conn_id: str = 'aws_default',
**kwargs,
) -> None:
super().__init__(**kwargs)
self.sftp_conn_id = sftp_conn_id
self.sftp_path = sftp_path
@@ -60,12 +67,12 @@ class S3ToSFTPOperator(BaseOperator):
self.s3_conn_id = s3_conn_id
@staticmethod
def get_s3_key(s3_key):
def get_s3_key(s3_key: str) -> str:
"""This parses the correct format for S3 keys regardless of how the S3 url is passed."""
parsed_s3_key = urlparse(s3_key)
return parsed_s3_key.path.lstrip('/')
def execute(self, context):
def execute(self, context) -> None:
self.s3_key = self.get_s3_key(self.s3_key)
ssh_hook = SSHHook(ssh_conn_id=self.sftp_conn_id)
s3_hook = S3Hook(self.s3_conn_id)
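A quick check of what the annotated get_s3_key static method returns for a full S3 URL (the bucket and key are invented; the same helper appears on SFTPToS3Operator below):

key = S3ToSFTPOperator.get_s3_key("s3://example-bucket/path/to/file.csv")
assert key == "path/to/file.csv"   # urlparse drops the scheme and bucket; lstrip removes the leading slash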

View file

@@ -50,8 +50,15 @@ class SFTPToS3Operator(BaseOperator):
@apply_defaults
def __init__(
self, *, s3_bucket, s3_key, sftp_path, sftp_conn_id='ssh_default', s3_conn_id='aws_default', **kwargs
):
self,
*,
s3_bucket: str,
s3_key: str,
sftp_path: str,
sftp_conn_id: str = 'ssh_default',
s3_conn_id: str = 'aws_default',
**kwargs,
) -> None:
super().__init__(**kwargs)
self.sftp_conn_id = sftp_conn_id
self.sftp_path = sftp_path
@@ -60,12 +67,12 @@ class SFTPToS3Operator(BaseOperator):
self.s3_conn_id = s3_conn_id
@staticmethod
def get_s3_key(s3_key):
def get_s3_key(s3_key: str) -> str:
"""This parses the correct format for S3 keys regardless of how the S3 url is passed."""
parsed_s3_key = urlparse(s3_key)
return parsed_s3_key.path.lstrip('/')
def execute(self, context):
def execute(self, context) -> None:
self.s3_key = self.get_s3_key(self.s3_key)
ssh_hook = SSHHook(ssh_conn_id=self.sftp_conn_id)
s3_hook = S3Hook(self.s3_conn_id)

View file

@@ -260,7 +260,9 @@ class GCSHook(GoogleBaseHook):
destination_bucket.name, # type: ignore[attr-defined]
)
def download(self, object_name: str, bucket_name: Optional[str], filename: Optional[str] = None) -> str:
def download(
self, object_name: str, bucket_name: Optional[str], filename: Optional[str] = None
) -> Union[str, bytes]:
"""
Downloads a file from Google Cloud Storage.
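The widened return annotation appears to reflect the two behaviours of GCSHook.download: with a filename it saves the object locally and returns a str path, without one it returns the object's contents as bytes. A hedged sketch using keyword arguments and made-up names (gcs_hook is assumed to be an initialized GCSHook):

local_path = gcs_hook.download(object_name="schema.json", bucket_name="example-bucket", filename="/tmp/schema.json")
raw_bytes = gcs_hook.download(object_name="schema.json", bucket_name="example-bucket")   # no filename -> bytes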

View file

@@ -947,9 +947,7 @@ class BigQueryCreateEmptyTableOperator(BaseOperator):
delegate_to=self.delegate_to,
impersonation_chain=self.impersonation_chain,
)
schema_fields = json.loads(
gcs_hook.download(gcs_bucket, gcs_object).decode("utf-8") # type: ignore[attr-defined]
) # type: ignore[attr-defined]
schema_fields = json.loads(gcs_hook.download(gcs_bucket, gcs_object))
else:
schema_fields = self.schema_fields
@@ -1189,8 +1187,7 @@ class BigQueryCreateExternalTableOperator(BaseOperator):
delegate_to=self.delegate_to,
impersonation_chain=self.impersonation_chain,
)
schema_object = gcs_hook.download(self.bucket, self.schema_object)
schema_fields = json.loads(schema_object.decode("utf-8")) # type: ignore[attr-defined]
schema_fields = json.loads(gcs_hook.download(self.bucket, self.schema_object))
else:
schema_fields = self.schema_fields
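The removed .decode("utf-8") calls and their # type: ignore comments are unnecessary because json.loads accepts bytes directly on Python 3.6+, so the Union[str, bytes] result of GCSHook.download can be passed straight through. A tiny standalone check:

import json

assert json.loads(b'{"fields": []}') == {"fields": []}   # bytes input is decoded as UTF-8 by json.loads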