From 0823d46a7f267f2e45195a175021825367938add Mon Sep 17 00:00:00 2001 From: Katsunori Kanda Date: Sat, 17 Oct 2020 02:54:08 +0900 Subject: [PATCH] Add type annotations for AWS operators and hooks (#11434) Co-authored-by: Tomek Urbaszek --- .../providers/amazon/aws/hooks/base_aws.py | 3 +- .../amazon/aws/transfers/dynamodb_to_s3.py | 12 ++--- .../amazon/aws/transfers/gcs_to_s3.py | 28 ++++++------ .../amazon/aws/transfers/glacier_to_gcs.py | 12 ++--- .../amazon/aws/transfers/google_api_to_s3.py | 44 +++++++++---------- .../amazon/aws/transfers/hive_to_dynamodb.py | 23 +++++----- .../aws/transfers/imap_attachment_to_s3.py | 20 ++++----- .../amazon/aws/transfers/mongo_to_s3.py | 31 +++++++------ .../amazon/aws/transfers/mysql_to_s3.py | 4 +- .../amazon/aws/transfers/redshift_to_s3.py | 2 +- .../amazon/aws/transfers/s3_to_redshift.py | 12 +++-- .../amazon/aws/transfers/s3_to_sftp.py | 15 +++++-- .../amazon/aws/transfers/sftp_to_s3.py | 15 +++++-- airflow/providers/google/cloud/hooks/gcs.py | 4 +- .../google/cloud/operators/bigquery.py | 7 +-- 15 files changed, 124 insertions(+), 108 deletions(-) diff --git a/airflow/providers/amazon/aws/hooks/base_aws.py b/airflow/providers/amazon/aws/hooks/base_aws.py index ab011f3038..56e49b3a56 100644 --- a/airflow/providers/amazon/aws/hooks/base_aws.py +++ b/airflow/providers/amazon/aws/hooks/base_aws.py @@ -29,6 +29,7 @@ import logging from typing import Any, Dict, Optional, Tuple, Union import boto3 +from botocore.credentials import ReadOnlyCredentials from botocore.config import Config from cached_property import cached_property @@ -393,7 +394,7 @@ class AwsBaseHook(BaseHook): session, _ = self._get_credentials(region_name) return session - def get_credentials(self, region_name: Optional[str] = None) -> Tuple[Optional[str], Optional[str]]: + def get_credentials(self, region_name: Optional[str] = None) -> ReadOnlyCredentials: """ Get the underlying `botocore.Credentials` object. 
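With this hunk, get_credentials() is annotated to return botocore's ReadOnlyCredentials, a frozen named tuple exposing access_key, secret_key, and token, instead of the previous (and inaccurate) Tuple[Optional[str], Optional[str]]. The body of the method is untouched, so runtime behaviour is unchanged; only the annotation now matches what the hook already returned. A minimal usage sketch follows; the connection id and client_type value are illustrative placeholders, not part of this patch:

    from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook

    # "aws_default" and client_type="s3" are illustrative placeholders.
    hook = AwsBaseHook(aws_conn_id="aws_default", client_type="s3")

    creds = hook.get_credentials()  # botocore.credentials.ReadOnlyCredentials
    # Attribute access on the frozen named tuple, now visible to type checkers.
    print(creds.access_key, creds.token is not None)
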
diff --git a/airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py b/airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py index 10cadd58d5..0d95ee7861 100644 --- a/airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py +++ b/airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py @@ -25,7 +25,7 @@ import json from copy import copy from os.path import getsize from tempfile import NamedTemporaryFile -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Dict, Optional, IO from uuid import uuid4 from airflow.models import BaseOperator @@ -34,11 +34,11 @@ from airflow.providers.amazon.aws.hooks.s3 import S3Hook from airflow.utils.decorators import apply_defaults -def _convert_item_to_json_bytes(item): +def _convert_item_to_json_bytes(item: Dict[str, Any]) -> bytes: return (json.dumps(item) + '\n').encode('utf-8') -def _upload_file_to_s3(file_obj, bucket_name, s3_key_prefix): +def _upload_file_to_s3(file_obj: IO, bucket_name: str, s3_key_prefix: str) -> None: s3_client = S3Hook().get_conn() file_obj.seek(0) s3_client.upload_file( @@ -102,7 +102,7 @@ class DynamoDBToS3Operator(BaseOperator): s3_key_prefix: str = '', process_func: Callable[[Dict[str, Any]], bytes] = _convert_item_to_json_bytes, **kwargs, - ): + ) -> None: super().__init__(**kwargs) self.file_size = file_size self.process_func = process_func @@ -111,7 +111,7 @@ class DynamoDBToS3Operator(BaseOperator): self.s3_bucket_name = s3_bucket_name self.s3_key_prefix = s3_key_prefix - def execute(self, context): + def execute(self, context) -> None: table = AwsDynamoDBHook().get_conn().Table(self.dynamodb_table_name) scan_kwargs = copy(self.dynamodb_scan_kwargs) if self.dynamodb_scan_kwargs else {} err = None @@ -126,7 +126,7 @@ class DynamoDBToS3Operator(BaseOperator): _upload_file_to_s3(f, self.s3_bucket_name, self.s3_key_prefix) f.close() - def _scan_dynamodb_and_upload_to_s3(self, temp_file, scan_kwargs, table): + def _scan_dynamodb_and_upload_to_s3(self, temp_file: IO, scan_kwargs: dict, table: Any) -> IO: while True: response = table.scan(**scan_kwargs) items = response['Items'] diff --git a/airflow/providers/amazon/aws/transfers/gcs_to_s3.py b/airflow/providers/amazon/aws/transfers/gcs_to_s3.py index 695087e861..d474f310c2 100644 --- a/airflow/providers/amazon/aws/transfers/gcs_to_s3.py +++ b/airflow/providers/amazon/aws/transfers/gcs_to_s3.py @@ -19,7 +19,7 @@ This module contains Google Cloud Storage to S3 operator. 
""" import warnings -from typing import Iterable, Optional, Sequence, Union, Dict +from typing import Iterable, Optional, Sequence, Union, Dict, List, cast from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.s3 import S3Hook @@ -100,21 +100,21 @@ class GCSToS3Operator(BaseOperator): def __init__( self, *, # pylint: disable=too-many-arguments - bucket, - prefix=None, - delimiter=None, - gcp_conn_id='google_cloud_default', - google_cloud_storage_conn_id=None, - delegate_to=None, - dest_aws_conn_id=None, - dest_s3_key=None, - dest_verify=None, - replace=False, + bucket: str, + prefix: Optional[str] = None, + delimiter: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + google_cloud_storage_conn_id: Optional[str] = None, + delegate_to: Optional[str] = None, + dest_aws_conn_id: str = 'aws_default', + dest_s3_key: str, + dest_verify: Optional[Union[str, bool]] = None, + replace: bool = False, google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None, dest_s3_extra_args: Optional[Dict] = None, s3_acl_policy: Optional[str] = None, **kwargs, - ): + ) -> None: super().__init__(**kwargs) if google_cloud_storage_conn_id: @@ -139,7 +139,7 @@ class GCSToS3Operator(BaseOperator): self.dest_s3_extra_args = dest_s3_extra_args or {} self.s3_acl_policy = s3_acl_policy - def execute(self, context): + def execute(self, context) -> List[str]: # list all files in an Google Cloud Storage bucket hook = GCSHook( google_cloud_storage_conn_id=self.gcp_conn_id, @@ -183,7 +183,7 @@ class GCSToS3Operator(BaseOperator): self.log.info("Saving file to %s", dest_key) s3_hook.load_bytes( - file_bytes, key=dest_key, replace=self.replace, acl_policy=self.s3_acl_policy + cast(bytes, file_bytes), key=dest_key, replace=self.replace, acl_policy=self.s3_acl_policy ) self.log.info("All done, uploaded %d files to S3", len(files)) diff --git a/airflow/providers/amazon/aws/transfers/glacier_to_gcs.py b/airflow/providers/amazon/aws/transfers/glacier_to_gcs.py index fb43d0fb10..3506003181 100644 --- a/airflow/providers/amazon/aws/transfers/glacier_to_gcs.py +++ b/airflow/providers/amazon/aws/transfers/glacier_to_gcs.py @@ -72,17 +72,17 @@ class GlacierToGCSOperator(BaseOperator): def __init__( self, *, - aws_conn_id="aws_default", - gcp_conn_id="google_cloud_default", + aws_conn_id: str = "aws_default", + gcp_conn_id: str = "google_cloud_default", vault_name: str, bucket_name: str, object_name: str, gzip: bool, - chunk_size=1024, - delegate_to=None, + chunk_size: int = 1024, + delegate_to: Optional[str] = None, google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None, **kwargs, - ): + ) -> None: super().__init__(**kwargs) self.aws_conn_id = aws_conn_id self.gcp_conn_id = gcp_conn_id @@ -94,7 +94,7 @@ class GlacierToGCSOperator(BaseOperator): self.delegate_to = delegate_to self.impersonation_chain = google_impersonation_chain - def execute(self, context): + def execute(self, context) -> str: glacier_hook = GlacierHook(aws_conn_id=self.aws_conn_id) gcs_hook = GCSHook( gcp_conn_id=self.gcp_conn_id, diff --git a/airflow/providers/amazon/aws/transfers/google_api_to_s3.py b/airflow/providers/amazon/aws/transfers/google_api_to_s3.py index ca17bed221..a8ee7db14d 100644 --- a/airflow/providers/amazon/aws/transfers/google_api_to_s3.py +++ b/airflow/providers/amazon/aws/transfers/google_api_to_s3.py @@ -23,8 +23,8 @@ import json import sys from typing import Optional, Sequence, Union -from airflow.models import BaseOperator -from airflow.models.xcom import MAX_XCOM_SIZE +from 
airflow.models import BaseOperator, TaskInstance +from airflow.models.xcom import MAX_XCOM_SIZE, XCOM_RETURN_KEY from airflow.providers.amazon.aws.hooks.s3 import S3Hook from airflow.providers.google.common.hooks.discovery_api import GoogleDiscoveryApiHook from airflow.utils.decorators import apply_defaults @@ -98,20 +98,20 @@ class GoogleApiToS3Operator(BaseOperator): def __init__( self, *, - google_api_service_name, - google_api_service_version, - google_api_endpoint_path, - google_api_endpoint_params, - s3_destination_key, - google_api_response_via_xcom=None, - google_api_endpoint_params_via_xcom=None, - google_api_endpoint_params_via_xcom_task_ids=None, - google_api_pagination=False, - google_api_num_retries=0, - s3_overwrite=False, - gcp_conn_id='google_cloud_default', - delegate_to=None, - aws_conn_id='aws_default', + google_api_service_name: str, + google_api_service_version: str, + google_api_endpoint_path: str, + google_api_endpoint_params: dict, + s3_destination_key: str, + google_api_response_via_xcom: Optional[str] = None, + google_api_endpoint_params_via_xcom: Optional[str] = None, + google_api_endpoint_params_via_xcom_task_ids: Optional[str] = None, + google_api_pagination: bool = False, + google_api_num_retries: int = 0, + s3_overwrite: bool = False, + gcp_conn_id: str = 'google_cloud_default', + delegate_to: Optional[str] = None, + aws_conn_id: str = 'aws_default', google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None, **kwargs, ): @@ -132,7 +132,7 @@ class GoogleApiToS3Operator(BaseOperator): self.aws_conn_id = aws_conn_id self.google_impersonation_chain = google_impersonation_chain - def execute(self, context): + def execute(self, context) -> None: """ Transfers Google APIs json data to S3. @@ -151,7 +151,7 @@ class GoogleApiToS3Operator(BaseOperator): if self.google_api_response_via_xcom: self._expose_google_api_response_via_xcom(context['task_instance'], data) - def _retrieve_data_from_google_api(self): + def _retrieve_data_from_google_api(self) -> dict: google_discovery_api_hook = GoogleDiscoveryApiHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -167,21 +167,21 @@ class GoogleApiToS3Operator(BaseOperator): ) return google_api_response - def _load_data_to_s3(self, data): + def _load_data_to_s3(self, data: dict) -> None: s3_hook = S3Hook(aws_conn_id=self.aws_conn_id) s3_hook.load_string( string_data=json.dumps(data), key=self.s3_destination_key, replace=self.s3_overwrite ) - def _update_google_api_endpoint_params_via_xcom(self, task_instance): + def _update_google_api_endpoint_params_via_xcom(self, task_instance: TaskInstance) -> None: google_api_endpoint_params = task_instance.xcom_pull( task_ids=self.google_api_endpoint_params_via_xcom_task_ids, key=self.google_api_endpoint_params_via_xcom, ) self.google_api_endpoint_params.update(google_api_endpoint_params) - def _expose_google_api_response_via_xcom(self, task_instance, data): + def _expose_google_api_response_via_xcom(self, task_instance: TaskInstance, data: dict) -> None: if sys.getsizeof(data) < MAX_XCOM_SIZE: - task_instance.xcom_push(key=self.google_api_response_via_xcom, value=data) + task_instance.xcom_push(key=self.google_api_response_via_xcom or XCOM_RETURN_KEY, value=data) else: raise RuntimeError('The size of the downloaded data is too large to push to XCom!') diff --git a/airflow/providers/amazon/aws/transfers/hive_to_dynamodb.py b/airflow/providers/amazon/aws/transfers/hive_to_dynamodb.py index 574d73a0d3..9a89719815 100644 --- 
a/airflow/providers/amazon/aws/transfers/hive_to_dynamodb.py +++ b/airflow/providers/amazon/aws/transfers/hive_to_dynamodb.py @@ -21,6 +21,7 @@ This module contains operator to move data from Hive to DynamoDB. """ import json +from typing import Optional, Callable from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.dynamodb import AwsDynamoDBHook @@ -64,18 +65,18 @@ class HiveToDynamoDBOperator(BaseOperator): def __init__( # pylint: disable=too-many-arguments self, *, - sql, - table_name, - table_keys, - pre_process=None, - pre_process_args=None, - pre_process_kwargs=None, - region_name=None, - schema='default', - hiveserver2_conn_id='hiveserver2_default', - aws_conn_id='aws_default', + sql: str, + table_name: str, + table_keys: list, + pre_process: Optional[Callable] = None, + pre_process_args: Optional[list] = None, + pre_process_kwargs: Optional[list] = None, + region_name: Optional[str] = None, + schema: str = 'default', + hiveserver2_conn_id: str = 'hiveserver2_default', + aws_conn_id: str = 'aws_default', **kwargs, - ): + ) -> None: super().__init__(**kwargs) self.sql = sql self.table_name = table_name diff --git a/airflow/providers/amazon/aws/transfers/imap_attachment_to_s3.py b/airflow/providers/amazon/aws/transfers/imap_attachment_to_s3.py index bf65b8f56d..b303d42aef 100644 --- a/airflow/providers/amazon/aws/transfers/imap_attachment_to_s3.py +++ b/airflow/providers/amazon/aws/transfers/imap_attachment_to_s3.py @@ -57,16 +57,16 @@ class ImapAttachmentToS3Operator(BaseOperator): def __init__( self, *, - imap_attachment_name, - s3_key, - imap_check_regex=False, - imap_mail_folder='INBOX', - imap_mail_filter='All', - s3_overwrite=False, - imap_conn_id='imap_default', - s3_conn_id='aws_default', + imap_attachment_name: str, + s3_key: str, + imap_check_regex: bool = False, + imap_mail_folder: str = 'INBOX', + imap_mail_filter: str = 'All', + s3_overwrite: bool = False, + imap_conn_id: str = 'imap_default', + s3_conn_id: str = 'aws_default', **kwargs, - ): + ) -> None: super().__init__(**kwargs) self.imap_attachment_name = imap_attachment_name self.s3_key = s3_key @@ -77,7 +77,7 @@ class ImapAttachmentToS3Operator(BaseOperator): self.imap_conn_id = imap_conn_id self.s3_conn_id = s3_conn_id - def execute(self, context): + def execute(self, context) -> None: """ This function executes the transfer from the email server (via imap) into s3. diff --git a/airflow/providers/amazon/aws/transfers/mongo_to_s3.py b/airflow/providers/amazon/aws/transfers/mongo_to_s3.py index b996e10ec7..58edc93ae2 100644 --- a/airflow/providers/amazon/aws/transfers/mongo_to_s3.py +++ b/airflow/providers/amazon/aws/transfers/mongo_to_s3.py @@ -16,6 +16,7 @@ # specific language governing permissions and limitations # under the License. 
import json +from typing import Optional, Any, Iterable, Union, cast from bson import json_util @@ -44,16 +45,16 @@ class MongoToS3Operator(BaseOperator): def __init__( self, *, - mongo_conn_id, - s3_conn_id, - mongo_collection, - mongo_query, - s3_bucket, - s3_key, - mongo_db=None, - replace=False, + mongo_conn_id: str, + s3_conn_id: str, + mongo_collection: str, + mongo_query: Union[list, dict], + s3_bucket: str, + s3_key: str, + mongo_db: Optional[str] = None, + replace: bool = False, **kwargs, - ): + ) -> None: super().__init__(**kwargs) # Conn Ids self.mongo_conn_id = mongo_conn_id @@ -70,7 +71,7 @@ class MongoToS3Operator(BaseOperator): self.s3_key = s3_key self.replace = replace - def execute(self, context): + def execute(self, context) -> bool: """ Executed by task_instance at runtime """ @@ -80,13 +81,15 @@ class MongoToS3Operator(BaseOperator): if self.is_pipeline: results = MongoHook(self.mongo_conn_id).aggregate( mongo_collection=self.mongo_collection, - aggregate_query=self.mongo_query, + aggregate_query=cast(list, self.mongo_query), mongo_db=self.mongo_db, ) else: results = MongoHook(self.mongo_conn_id).find( - mongo_collection=self.mongo_collection, query=self.mongo_query, mongo_db=self.mongo_db + mongo_collection=self.mongo_collection, + query=cast(dict, self.mongo_query), + mongo_db=self.mongo_db, ) # Performs transform then stringifies the docs results into json format @@ -100,7 +103,7 @@ class MongoToS3Operator(BaseOperator): return True @staticmethod - def _stringify(iterable, joinable='\n'): + def _stringify(iterable: Iterable, joinable: str = '\n') -> str: """ Takes an iterable (pymongo Cursor or Array) containing dictionaries and returns a stringified version using python join @@ -108,7 +111,7 @@ class MongoToS3Operator(BaseOperator): return joinable.join([json.dumps(doc, default=json_util.default) for doc in iterable]) @staticmethod - def transform(docs): + def transform(docs: Any) -> Any: """ Processes pyMongo cursor and returns an iterable with each element being a JSON serializable dictionary diff --git a/airflow/providers/amazon/aws/transfers/mysql_to_s3.py b/airflow/providers/amazon/aws/transfers/mysql_to_s3.py index ce6ed54789..33ffd59e00 100644 --- a/airflow/providers/amazon/aws/transfers/mysql_to_s3.py +++ b/airflow/providers/amazon/aws/transfers/mysql_to_s3.py @@ -101,7 +101,7 @@ class MySQLToS3Operator(BaseOperator): if "header" not in self.pd_csv_kwargs: self.pd_csv_kwargs["header"] = header - def _fix_int_dtypes(self, df): + def _fix_int_dtypes(self, df: pd.DataFrame) -> None: """ Mutate DataFrame to set dtypes for int columns containing NaN values." 
""" @@ -113,7 +113,7 @@ class MySQLToS3Operator(BaseOperator): # set to dtype that retains integers and supports NaNs df[col] = np.where(df[col].isnull(), None, df[col]).astype(pd.Int64Dtype) - def execute(self, context): + def execute(self, context) -> None: mysql_hook = MySqlHook(mysql_conn_id=self.mysql_conn_id) s3_conn = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify) data_df = mysql_hook.get_pandas_df(self.query) diff --git a/airflow/providers/amazon/aws/transfers/redshift_to_s3.py b/airflow/providers/amazon/aws/transfers/redshift_to_s3.py index fe51401fd6..9bc97936a6 100644 --- a/airflow/providers/amazon/aws/transfers/redshift_to_s3.py +++ b/airflow/providers/amazon/aws/transfers/redshift_to_s3.py @@ -104,7 +104,7 @@ class RedshiftToS3Operator(BaseOperator): 'HEADER', ] - def execute(self, context): + def execute(self, context) -> None: postgres_hook = PostgresHook(postgres_conn_id=self.redshift_conn_id) s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify) diff --git a/airflow/providers/amazon/aws/transfers/s3_to_redshift.py b/airflow/providers/amazon/aws/transfers/s3_to_redshift.py index 5852c46dfb..9abbe0a8ef 100644 --- a/airflow/providers/amazon/aws/transfers/s3_to_redshift.py +++ b/airflow/providers/amazon/aws/transfers/s3_to_redshift.py @@ -87,13 +87,11 @@ class S3ToRedshiftOperator(BaseOperator): self.verify = verify self.copy_options = copy_options or [] self.autocommit = autocommit - self._s3_hook = None - self._postgres_hook = None - def execute(self, context): - self._postgres_hook = PostgresHook(postgres_conn_id=self.redshift_conn_id) - self._s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify) - credentials = self._s3_hook.get_credentials() + def execute(self, context) -> None: + postgres_hook = PostgresHook(postgres_conn_id=self.redshift_conn_id) + s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify) + credentials = s3_hook.get_credentials() copy_options = '\n\t\t\t'.join(self.copy_options) copy_query = """ @@ -113,5 +111,5 @@ class S3ToRedshiftOperator(BaseOperator): ) self.log.info('Executing COPY command...') - self._postgres_hook.run(copy_query, self.autocommit) + postgres_hook.run(copy_query, self.autocommit) self.log.info("COPY command complete...") diff --git a/airflow/providers/amazon/aws/transfers/s3_to_sftp.py b/airflow/providers/amazon/aws/transfers/s3_to_sftp.py index 97b1918d7d..c94ff5f6c5 100644 --- a/airflow/providers/amazon/aws/transfers/s3_to_sftp.py +++ b/airflow/providers/amazon/aws/transfers/s3_to_sftp.py @@ -50,8 +50,15 @@ class S3ToSFTPOperator(BaseOperator): @apply_defaults def __init__( - self, *, s3_bucket, s3_key, sftp_path, sftp_conn_id='ssh_default', s3_conn_id='aws_default', **kwargs - ): + self, + *, + s3_bucket: str, + s3_key: str, + sftp_path: str, + sftp_conn_id: str = 'ssh_default', + s3_conn_id: str = 'aws_default', + **kwargs, + ) -> None: super().__init__(**kwargs) self.sftp_conn_id = sftp_conn_id self.sftp_path = sftp_path @@ -60,12 +67,12 @@ class S3ToSFTPOperator(BaseOperator): self.s3_conn_id = s3_conn_id @staticmethod - def get_s3_key(s3_key): + def get_s3_key(s3_key: str) -> str: """This parses the correct format for S3 keys regardless of how the S3 url is passed.""" parsed_s3_key = urlparse(s3_key) return parsed_s3_key.path.lstrip('/') - def execute(self, context): + def execute(self, context) -> None: self.s3_key = self.get_s3_key(self.s3_key) ssh_hook = SSHHook(ssh_conn_id=self.sftp_conn_id) s3_hook = S3Hook(self.s3_conn_id) diff --git 
a/airflow/providers/amazon/aws/transfers/sftp_to_s3.py b/airflow/providers/amazon/aws/transfers/sftp_to_s3.py index 6cf1a734ba..482537f0b2 100644 --- a/airflow/providers/amazon/aws/transfers/sftp_to_s3.py +++ b/airflow/providers/amazon/aws/transfers/sftp_to_s3.py @@ -50,8 +50,15 @@ class SFTPToS3Operator(BaseOperator): @apply_defaults def __init__( - self, *, s3_bucket, s3_key, sftp_path, sftp_conn_id='ssh_default', s3_conn_id='aws_default', **kwargs - ): + self, + *, + s3_bucket: str, + s3_key: str, + sftp_path: str, + sftp_conn_id: str = 'ssh_default', + s3_conn_id: str = 'aws_default', + **kwargs, + ) -> None: super().__init__(**kwargs) self.sftp_conn_id = sftp_conn_id self.sftp_path = sftp_path @@ -60,12 +67,12 @@ class SFTPToS3Operator(BaseOperator): self.s3_conn_id = s3_conn_id @staticmethod - def get_s3_key(s3_key): + def get_s3_key(s3_key: str) -> str: """This parses the correct format for S3 keys regardless of how the S3 url is passed.""" parsed_s3_key = urlparse(s3_key) return parsed_s3_key.path.lstrip('/') - def execute(self, context): + def execute(self, context) -> None: self.s3_key = self.get_s3_key(self.s3_key) ssh_hook = SSHHook(ssh_conn_id=self.sftp_conn_id) s3_hook = S3Hook(self.s3_conn_id) diff --git a/airflow/providers/google/cloud/hooks/gcs.py b/airflow/providers/google/cloud/hooks/gcs.py index 1eb17ce88f..68ac78c13e 100644 --- a/airflow/providers/google/cloud/hooks/gcs.py +++ b/airflow/providers/google/cloud/hooks/gcs.py @@ -260,7 +260,9 @@ class GCSHook(GoogleBaseHook): destination_bucket.name, # type: ignore[attr-defined] ) - def download(self, object_name: str, bucket_name: Optional[str], filename: Optional[str] = None) -> str: + def download( + self, object_name: str, bucket_name: Optional[str], filename: Optional[str] = None + ) -> Union[str, bytes]: """ Downloads a file from Google Cloud Storage. diff --git a/airflow/providers/google/cloud/operators/bigquery.py b/airflow/providers/google/cloud/operators/bigquery.py index 0686b33e2f..c208d2a2d8 100644 --- a/airflow/providers/google/cloud/operators/bigquery.py +++ b/airflow/providers/google/cloud/operators/bigquery.py @@ -947,9 +947,7 @@ class BigQueryCreateEmptyTableOperator(BaseOperator): delegate_to=self.delegate_to, impersonation_chain=self.impersonation_chain, ) - schema_fields = json.loads( - gcs_hook.download(gcs_bucket, gcs_object).decode("utf-8") # type: ignore[attr-defined] - ) # type: ignore[attr-defined] + schema_fields = json.loads(gcs_hook.download(gcs_bucket, gcs_object)) else: schema_fields = self.schema_fields @@ -1189,8 +1187,7 @@ class BigQueryCreateExternalTableOperator(BaseOperator): delegate_to=self.delegate_to, impersonation_chain=self.impersonation_chain, ) - schema_object = gcs_hook.download(self.bucket, self.schema_object) - schema_fields = json.loads(schema_object.decode("utf-8")) # type: ignore[attr-defined] + schema_fields = json.loads(gcs_hook.download(self.bucket, self.schema_object)) else: schema_fields = self.schema_fields
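Taken together, the annotated constructors document the expected argument types directly in the operator signatures. A minimal sketch of wiring one of the annotated transfer operators into a DAG, using only the parameters visible in this patch (the DAG id, bucket, key, and remote path are illustrative placeholders):

    from airflow import DAG
    from airflow.providers.amazon.aws.transfers.s3_to_sftp import S3ToSFTPOperator
    from airflow.utils.dates import days_ago

    # Bucket, key, and path values below are illustrative placeholders.
    with DAG(dag_id="s3_to_sftp_example", start_date=days_ago(1), schedule_interval=None) as dag:
        transfer = S3ToSFTPOperator(
            task_id="copy_report",
            s3_bucket="example-bucket",
            s3_key="reports/2020-10-17.csv",
            sftp_path="/upload/2020-10-17.csv",
            sftp_conn_id="ssh_default",  # default per the annotated signature
            s3_conn_id="aws_default",    # default per the annotated signature
        )

Against the annotated signature, a static checker such as mypy can flag a mistyped argument (for example, a non-string s3_bucket) when the DAG file is checked, rather than when the task runs.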