Add type annotations for AWS operators and hooks (#11434)

Co-authored-by: Tomek Urbaszek <turbaszek@gmail.com>

Parent: 3c10ca6504
Commit: 0823d46a7f
@@ -29,6 +29,7 @@ import logging
 from typing import Any, Dict, Optional, Tuple, Union

 import boto3
+from botocore.credentials import ReadOnlyCredentials
 from botocore.config import Config
 from cached_property import cached_property

@@ -393,7 +394,7 @@ class AwsBaseHook(BaseHook):
         session, _ = self._get_credentials(region_name)
         return session

-    def get_credentials(self, region_name: Optional[str] = None) -> Tuple[Optional[str], Optional[str]]:
+    def get_credentials(self, region_name: Optional[str] = None) -> ReadOnlyCredentials:
         """
         Get the underlying `botocore.Credentials` object.

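As context for the hunk above, a minimal sketch (not part of this commit) of how a caller consumes the new ReadOnlyCredentials return value; the aws_conn_id and client_type arguments are illustrative assumptions.

from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook

# Hypothetical hook; the connection id and client_type are placeholders.
hook = AwsBaseHook(aws_conn_id="aws_default", client_type="s3")
creds = hook.get_credentials()
# ReadOnlyCredentials is a named tuple, so named attributes replace tuple unpacking:
print(creds.access_key, creds.secret_key, creds.token)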
@@ -25,7 +25,7 @@ import json
 from copy import copy
 from os.path import getsize
 from tempfile import NamedTemporaryFile
-from typing import Any, Callable, Dict, Optional
+from typing import Any, Callable, Dict, Optional, IO
 from uuid import uuid4

 from airflow.models import BaseOperator
@@ -34,11 +34,11 @@ from airflow.providers.amazon.aws.hooks.s3 import S3Hook
 from airflow.utils.decorators import apply_defaults


-def _convert_item_to_json_bytes(item):
+def _convert_item_to_json_bytes(item: Dict[str, Any]) -> bytes:
     return (json.dumps(item) + '\n').encode('utf-8')


-def _upload_file_to_s3(file_obj, bucket_name, s3_key_prefix):
+def _upload_file_to_s3(file_obj: IO, bucket_name: str, s3_key_prefix: str) -> None:
     s3_client = S3Hook().get_conn()
     file_obj.seek(0)
     s3_client.upload_file(
@@ -102,7 +102,7 @@ class DynamoDBToS3Operator(BaseOperator):
         s3_key_prefix: str = '',
         process_func: Callable[[Dict[str, Any]], bytes] = _convert_item_to_json_bytes,
         **kwargs,
-    ):
+    ) -> None:
         super().__init__(**kwargs)
         self.file_size = file_size
         self.process_func = process_func
@@ -111,7 +111,7 @@ class DynamoDBToS3Operator(BaseOperator):
         self.s3_bucket_name = s3_bucket_name
         self.s3_key_prefix = s3_key_prefix

-    def execute(self, context):
+    def execute(self, context) -> None:
         table = AwsDynamoDBHook().get_conn().Table(self.dynamodb_table_name)
         scan_kwargs = copy(self.dynamodb_scan_kwargs) if self.dynamodb_scan_kwargs else {}
         err = None
@@ -126,7 +126,7 @@ class DynamoDBToS3Operator(BaseOperator):
             _upload_file_to_s3(f, self.s3_bucket_name, self.s3_key_prefix)
             f.close()

-    def _scan_dynamodb_and_upload_to_s3(self, temp_file, scan_kwargs, table):
+    def _scan_dynamodb_and_upload_to_s3(self, temp_file: IO, scan_kwargs: dict, table: Any) -> IO:
         while True:
             response = table.scan(**scan_kwargs)
             items = response['Items']
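A hedged sketch of a custom callable that satisfies the Callable[[Dict[str, Any]], bytes] annotation on process_func shown above; the serialization choice is illustrative, not taken from the commit.

import json
from typing import Any, Dict

def to_compact_json_bytes(item: Dict[str, Any]) -> bytes:
    # Matches the process_func signature: one DynamoDB item in, encoded bytes out.
    return (json.dumps(item, separators=(",", ":")) + "\n").encode("utf-8")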
@@ -19,7 +19,7 @@
 This module contains Google Cloud Storage to S3 operator.
 """
 import warnings
-from typing import Iterable, Optional, Sequence, Union, Dict
+from typing import Iterable, Optional, Sequence, Union, Dict, List, cast

 from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
@@ -100,21 +100,21 @@ class GCSToS3Operator(BaseOperator):
     def __init__(
         self,
         *,  # pylint: disable=too-many-arguments
-        bucket,
-        prefix=None,
-        delimiter=None,
-        gcp_conn_id='google_cloud_default',
-        google_cloud_storage_conn_id=None,
-        delegate_to=None,
-        dest_aws_conn_id=None,
-        dest_s3_key=None,
-        dest_verify=None,
-        replace=False,
+        bucket: str,
+        prefix: Optional[str] = None,
+        delimiter: Optional[str] = None,
+        gcp_conn_id: str = 'google_cloud_default',
+        google_cloud_storage_conn_id: Optional[str] = None,
+        delegate_to: Optional[str] = None,
+        dest_aws_conn_id: str = 'aws_default',
+        dest_s3_key: str,
+        dest_verify: Optional[Union[str, bool]] = None,
+        replace: bool = False,
         google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
         dest_s3_extra_args: Optional[Dict] = None,
         s3_acl_policy: Optional[str] = None,
         **kwargs,
-    ):
+    ) -> None:
         super().__init__(**kwargs)

         if google_cloud_storage_conn_id:
@@ -139,7 +139,7 @@ class GCSToS3Operator(BaseOperator):
         self.dest_s3_extra_args = dest_s3_extra_args or {}
         self.s3_acl_policy = s3_acl_policy

-    def execute(self, context):
+    def execute(self, context) -> List[str]:
         # list all files in an Google Cloud Storage bucket
         hook = GCSHook(
             google_cloud_storage_conn_id=self.gcp_conn_id,
@@ -183,7 +183,7 @@ class GCSToS3Operator(BaseOperator):
             self.log.info("Saving file to %s", dest_key)

             s3_hook.load_bytes(
-                file_bytes, key=dest_key, replace=self.replace, acl_policy=self.s3_acl_policy
+                cast(bytes, file_bytes), key=dest_key, replace=self.replace, acl_policy=self.s3_acl_policy
             )

         self.log.info("All done, uploaded %d files to S3", len(files))
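The cast() in the last hunk is purely a type-checker hint; a small standalone sketch of the pattern (not the operator's code):

from typing import Union, cast

def narrow_to_bytes(file_bytes: Union[str, bytes]) -> bytes:
    # cast() is a no-op at runtime; it only tells mypy to treat the value as bytes.
    return cast(bytes, file_bytes)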
@@ -72,17 +72,17 @@ class GlacierToGCSOperator(BaseOperator):
     def __init__(
         self,
         *,
-        aws_conn_id="aws_default",
-        gcp_conn_id="google_cloud_default",
+        aws_conn_id: str = "aws_default",
+        gcp_conn_id: str = "google_cloud_default",
         vault_name: str,
         bucket_name: str,
         object_name: str,
         gzip: bool,
-        chunk_size=1024,
-        delegate_to=None,
+        chunk_size: int = 1024,
+        delegate_to: Optional[str] = None,
         google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
         **kwargs,
-    ):
+    ) -> None:
         super().__init__(**kwargs)
         self.aws_conn_id = aws_conn_id
         self.gcp_conn_id = gcp_conn_id
@@ -94,7 +94,7 @@ class GlacierToGCSOperator(BaseOperator):
         self.delegate_to = delegate_to
         self.impersonation_chain = google_impersonation_chain

-    def execute(self, context):
+    def execute(self, context) -> str:
         glacier_hook = GlacierHook(aws_conn_id=self.aws_conn_id)
         gcs_hook = GCSHook(
             gcp_conn_id=self.gcp_conn_id,
@@ -23,8 +23,8 @@ import json
 import sys
 from typing import Optional, Sequence, Union

-from airflow.models import BaseOperator
-from airflow.models.xcom import MAX_XCOM_SIZE
+from airflow.models import BaseOperator, TaskInstance
+from airflow.models.xcom import MAX_XCOM_SIZE, XCOM_RETURN_KEY
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
 from airflow.providers.google.common.hooks.discovery_api import GoogleDiscoveryApiHook
 from airflow.utils.decorators import apply_defaults
@@ -98,20 +98,20 @@ class GoogleApiToS3Operator(BaseOperator):
     def __init__(
         self,
         *,
-        google_api_service_name,
-        google_api_service_version,
-        google_api_endpoint_path,
-        google_api_endpoint_params,
-        s3_destination_key,
-        google_api_response_via_xcom=None,
-        google_api_endpoint_params_via_xcom=None,
-        google_api_endpoint_params_via_xcom_task_ids=None,
-        google_api_pagination=False,
-        google_api_num_retries=0,
-        s3_overwrite=False,
-        gcp_conn_id='google_cloud_default',
-        delegate_to=None,
-        aws_conn_id='aws_default',
+        google_api_service_name: str,
+        google_api_service_version: str,
+        google_api_endpoint_path: str,
+        google_api_endpoint_params: dict,
+        s3_destination_key: str,
+        google_api_response_via_xcom: Optional[str] = None,
+        google_api_endpoint_params_via_xcom: Optional[str] = None,
+        google_api_endpoint_params_via_xcom_task_ids: Optional[str] = None,
+        google_api_pagination: bool = False,
+        google_api_num_retries: int = 0,
+        s3_overwrite: bool = False,
+        gcp_conn_id: str = 'google_cloud_default',
+        delegate_to: Optional[str] = None,
+        aws_conn_id: str = 'aws_default',
         google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
         **kwargs,
     ):
@@ -132,7 +132,7 @@ class GoogleApiToS3Operator(BaseOperator):
         self.aws_conn_id = aws_conn_id
         self.google_impersonation_chain = google_impersonation_chain

-    def execute(self, context):
+    def execute(self, context) -> None:
         """
         Transfers Google APIs json data to S3.

@@ -151,7 +151,7 @@ class GoogleApiToS3Operator(BaseOperator):
         if self.google_api_response_via_xcom:
             self._expose_google_api_response_via_xcom(context['task_instance'], data)

-    def _retrieve_data_from_google_api(self):
+    def _retrieve_data_from_google_api(self) -> dict:
         google_discovery_api_hook = GoogleDiscoveryApiHook(
             gcp_conn_id=self.gcp_conn_id,
             delegate_to=self.delegate_to,
@@ -167,21 +167,21 @@ class GoogleApiToS3Operator(BaseOperator):
         )
         return google_api_response

-    def _load_data_to_s3(self, data):
+    def _load_data_to_s3(self, data: dict) -> None:
         s3_hook = S3Hook(aws_conn_id=self.aws_conn_id)
         s3_hook.load_string(
             string_data=json.dumps(data), key=self.s3_destination_key, replace=self.s3_overwrite
         )

-    def _update_google_api_endpoint_params_via_xcom(self, task_instance):
+    def _update_google_api_endpoint_params_via_xcom(self, task_instance: TaskInstance) -> None:
         google_api_endpoint_params = task_instance.xcom_pull(
             task_ids=self.google_api_endpoint_params_via_xcom_task_ids,
             key=self.google_api_endpoint_params_via_xcom,
         )
         self.google_api_endpoint_params.update(google_api_endpoint_params)

-    def _expose_google_api_response_via_xcom(self, task_instance, data):
+    def _expose_google_api_response_via_xcom(self, task_instance: TaskInstance, data: dict) -> None:
         if sys.getsizeof(data) < MAX_XCOM_SIZE:
-            task_instance.xcom_push(key=self.google_api_response_via_xcom, value=data)
+            task_instance.xcom_push(key=self.google_api_response_via_xcom or XCOM_RETURN_KEY, value=data)
         else:
             raise RuntimeError('The size of the downloaded data is too large to push to XCom!')
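A hedged sketch of the `or XCOM_RETURN_KEY` fallback used above: it narrows the Optional[str] attribute to str for mypy and keeps the default XCom key at runtime when none was configured. The constant's value below is assumed to mirror airflow.models.xcom and is shown only for illustration.

from typing import Optional

XCOM_RETURN_KEY = "return_value"  # assumed to mirror airflow.models.xcom

def resolve_xcom_key(configured_key: Optional[str]) -> str:
    return configured_key or XCOM_RETURN_KEY

print(resolve_xcom_key(None))        # return_value
print(resolve_xcom_key("api_data"))  # api_data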
@@ -21,6 +21,7 @@ This module contains operator to move data from Hive to DynamoDB.
 """

 import json
+from typing import Optional, Callable

 from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.dynamodb import AwsDynamoDBHook
@@ -64,18 +65,18 @@ class HiveToDynamoDBOperator(BaseOperator):
     def __init__(  # pylint: disable=too-many-arguments
         self,
         *,
-        sql,
-        table_name,
-        table_keys,
-        pre_process=None,
-        pre_process_args=None,
-        pre_process_kwargs=None,
-        region_name=None,
-        schema='default',
-        hiveserver2_conn_id='hiveserver2_default',
-        aws_conn_id='aws_default',
+        sql: str,
+        table_name: str,
+        table_keys: list,
+        pre_process: Optional[Callable] = None,
+        pre_process_args: Optional[list] = None,
+        pre_process_kwargs: Optional[list] = None,
+        region_name: Optional[str] = None,
+        schema: str = 'default',
+        hiveserver2_conn_id: str = 'hiveserver2_default',
+        aws_conn_id: str = 'aws_default',
         **kwargs,
-    ):
+    ) -> None:
         super().__init__(**kwargs)
         self.sql = sql
         self.table_name = table_name
@@ -57,16 +57,16 @@ class ImapAttachmentToS3Operator(BaseOperator):
     def __init__(
         self,
         *,
-        imap_attachment_name,
-        s3_key,
-        imap_check_regex=False,
-        imap_mail_folder='INBOX',
-        imap_mail_filter='All',
-        s3_overwrite=False,
-        imap_conn_id='imap_default',
-        s3_conn_id='aws_default',
+        imap_attachment_name: str,
+        s3_key: str,
+        imap_check_regex: bool = False,
+        imap_mail_folder: str = 'INBOX',
+        imap_mail_filter: str = 'All',
+        s3_overwrite: bool = False,
+        imap_conn_id: str = 'imap_default',
+        s3_conn_id: str = 'aws_default',
         **kwargs,
-    ):
+    ) -> None:
         super().__init__(**kwargs)
         self.imap_attachment_name = imap_attachment_name
         self.s3_key = s3_key
@@ -77,7 +77,7 @@ class ImapAttachmentToS3Operator(BaseOperator):
         self.imap_conn_id = imap_conn_id
         self.s3_conn_id = s3_conn_id

-    def execute(self, context):
+    def execute(self, context) -> None:
         """
         This function executes the transfer from the email server (via imap) into s3.

@@ -16,6 +16,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import json
+from typing import Optional, Any, Iterable, Union, cast

 from bson import json_util

@@ -44,16 +45,16 @@ class MongoToS3Operator(BaseOperator):
     def __init__(
         self,
         *,
-        mongo_conn_id,
-        s3_conn_id,
-        mongo_collection,
-        mongo_query,
-        s3_bucket,
-        s3_key,
-        mongo_db=None,
-        replace=False,
+        mongo_conn_id: str,
+        s3_conn_id: str,
+        mongo_collection: str,
+        mongo_query: Union[list, dict],
+        s3_bucket: str,
+        s3_key: str,
+        mongo_db: Optional[str] = None,
+        replace: bool = False,
         **kwargs,
-    ):
+    ) -> None:
         super().__init__(**kwargs)
         # Conn Ids
         self.mongo_conn_id = mongo_conn_id
@@ -70,7 +71,7 @@ class MongoToS3Operator(BaseOperator):
         self.s3_key = s3_key
         self.replace = replace

-    def execute(self, context):
+    def execute(self, context) -> bool:
         """
         Executed by task_instance at runtime
         """
@@ -80,13 +81,15 @@ class MongoToS3Operator(BaseOperator):
         if self.is_pipeline:
             results = MongoHook(self.mongo_conn_id).aggregate(
                 mongo_collection=self.mongo_collection,
-                aggregate_query=self.mongo_query,
+                aggregate_query=cast(list, self.mongo_query),
                 mongo_db=self.mongo_db,
             )

         else:
             results = MongoHook(self.mongo_conn_id).find(
-                mongo_collection=self.mongo_collection, query=self.mongo_query, mongo_db=self.mongo_db
+                mongo_collection=self.mongo_collection,
+                query=cast(dict, self.mongo_query),
+                mongo_db=self.mongo_db,
             )

         # Performs transform then stringifies the docs results into json format
@@ -100,7 +103,7 @@ class MongoToS3Operator(BaseOperator):
             return True

     @staticmethod
-    def _stringify(iterable, joinable='\n'):
+    def _stringify(iterable: Iterable, joinable: str = '\n') -> str:
         """
         Takes an iterable (pymongo Cursor or Array) containing dictionaries and
         returns a stringified version using python join
@@ -108,7 +111,7 @@ class MongoToS3Operator(BaseOperator):
         return joinable.join([json.dumps(doc, default=json_util.default) for doc in iterable])

     @staticmethod
-    def transform(docs):
+    def transform(docs: Any) -> Any:
         """
         Processes pyMongo cursor and returns an iterable with each element being
         a JSON serializable dictionary
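A hedged sketch of why mongo_query is typed Union[list, dict] and then cast(): a list is handled as an aggregation pipeline and a dict as a find() filter, so cast() narrows the union for mypy on each branch (standalone illustration, not the operator's code).

from typing import Union, cast

def describe_query(mongo_query: Union[list, dict]) -> str:
    if isinstance(mongo_query, list):
        pipeline = cast(list, mongo_query)  # aggregation pipeline stages
        return f"aggregate with {len(pipeline)} stage(s)"
    query = cast(dict, mongo_query)  # plain find() filter document
    return f"find with {len(query)} filter key(s)"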
@@ -101,7 +101,7 @@ class MySQLToS3Operator(BaseOperator):
         if "header" not in self.pd_csv_kwargs:
             self.pd_csv_kwargs["header"] = header

-    def _fix_int_dtypes(self, df):
+    def _fix_int_dtypes(self, df: pd.DataFrame) -> None:
         """
         Mutate DataFrame to set dtypes for int columns containing NaN values."
         """
@@ -113,7 +113,7 @@ class MySQLToS3Operator(BaseOperator):
                 # set to dtype that retains integers and supports NaNs
                 df[col] = np.where(df[col].isnull(), None, df[col]).astype(pd.Int64Dtype)

-    def execute(self, context):
+    def execute(self, context) -> None:
         mysql_hook = MySqlHook(mysql_conn_id=self.mysql_conn_id)
         s3_conn = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
         data_df = mysql_hook.get_pandas_df(self.query)
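For background on the _fix_int_dtypes context above, a small standalone sketch of pandas' nullable integer dtype, which keeps integer columns integer even when values are missing (illustration only, not the operator's code).

import pandas as pd

s = pd.Series(pd.array([1, None, 3], dtype="Int64"))  # nullable integer dtype
print(s.dtype)   # Int64
print(s.isna())  # the missing entry stays missing without forcing float64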
@@ -104,7 +104,7 @@ class RedshiftToS3Operator(BaseOperator):
             'HEADER',
         ]

-    def execute(self, context):
+    def execute(self, context) -> None:
         postgres_hook = PostgresHook(postgres_conn_id=self.redshift_conn_id)
         s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)

@@ -87,13 +87,11 @@ class S3ToRedshiftOperator(BaseOperator):
         self.verify = verify
         self.copy_options = copy_options or []
         self.autocommit = autocommit
-        self._s3_hook = None
-        self._postgres_hook = None

-    def execute(self, context):
-        self._postgres_hook = PostgresHook(postgres_conn_id=self.redshift_conn_id)
-        self._s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
-        credentials = self._s3_hook.get_credentials()
+    def execute(self, context) -> None:
+        postgres_hook = PostgresHook(postgres_conn_id=self.redshift_conn_id)
+        s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
+        credentials = s3_hook.get_credentials()
         copy_options = '\n\t\t\t'.join(self.copy_options)

         copy_query = """
@@ -113,5 +111,5 @@ class S3ToRedshiftOperator(BaseOperator):
         )

         self.log.info('Executing COPY command...')
-        self._postgres_hook.run(copy_query, self.autocommit)
+        postgres_hook.run(copy_query, self.autocommit)
         self.log.info("COPY command complete...")
@@ -50,8 +50,15 @@ class S3ToSFTPOperator(BaseOperator):

     @apply_defaults
     def __init__(
-        self, *, s3_bucket, s3_key, sftp_path, sftp_conn_id='ssh_default', s3_conn_id='aws_default', **kwargs
-    ):
+        self,
+        *,
+        s3_bucket: str,
+        s3_key: str,
+        sftp_path: str,
+        sftp_conn_id: str = 'ssh_default',
+        s3_conn_id: str = 'aws_default',
+        **kwargs,
+    ) -> None:
         super().__init__(**kwargs)
         self.sftp_conn_id = sftp_conn_id
         self.sftp_path = sftp_path
@@ -60,12 +67,12 @@ class S3ToSFTPOperator(BaseOperator):
         self.s3_conn_id = s3_conn_id

     @staticmethod
-    def get_s3_key(s3_key):
+    def get_s3_key(s3_key: str) -> str:
         """This parses the correct format for S3 keys regardless of how the S3 url is passed."""
         parsed_s3_key = urlparse(s3_key)
         return parsed_s3_key.path.lstrip('/')

-    def execute(self, context):
+    def execute(self, context) -> None:
         self.s3_key = self.get_s3_key(self.s3_key)
         ssh_hook = SSHHook(ssh_conn_id=self.sftp_conn_id)
         s3_hook = S3Hook(self.s3_conn_id)
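A quick standalone illustration of the get_s3_key helper annotated above: urlparse keeps only the path component, so a full s3:// URL and a bare key resolve to the same result (the sample values are made up).

from urllib.parse import urlparse

def get_s3_key(s3_key: str) -> str:
    return urlparse(s3_key).path.lstrip('/')

print(get_s3_key("s3://my-bucket/data/file.csv"))  # data/file.csv
print(get_s3_key("data/file.csv"))                 # data/file.csv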
@@ -50,8 +50,15 @@ class SFTPToS3Operator(BaseOperator):

     @apply_defaults
     def __init__(
-        self, *, s3_bucket, s3_key, sftp_path, sftp_conn_id='ssh_default', s3_conn_id='aws_default', **kwargs
-    ):
+        self,
+        *,
+        s3_bucket: str,
+        s3_key: str,
+        sftp_path: str,
+        sftp_conn_id: str = 'ssh_default',
+        s3_conn_id: str = 'aws_default',
+        **kwargs,
+    ) -> None:
         super().__init__(**kwargs)
         self.sftp_conn_id = sftp_conn_id
         self.sftp_path = sftp_path
@@ -60,12 +67,12 @@ class SFTPToS3Operator(BaseOperator):
         self.s3_conn_id = s3_conn_id

     @staticmethod
-    def get_s3_key(s3_key):
+    def get_s3_key(s3_key: str) -> str:
         """This parses the correct format for S3 keys regardless of how the S3 url is passed."""
         parsed_s3_key = urlparse(s3_key)
         return parsed_s3_key.path.lstrip('/')

-    def execute(self, context):
+    def execute(self, context) -> None:
         self.s3_key = self.get_s3_key(self.s3_key)
         ssh_hook = SSHHook(ssh_conn_id=self.sftp_conn_id)
         s3_hook = S3Hook(self.s3_conn_id)
@@ -260,7 +260,9 @@ class GCSHook(GoogleBaseHook):
             destination_bucket.name,  # type: ignore[attr-defined]
         )

-    def download(self, object_name: str, bucket_name: Optional[str], filename: Optional[str] = None) -> str:
+    def download(
+        self, object_name: str, bucket_name: Optional[str], filename: Optional[str] = None
+    ) -> Union[str, bytes]:
         """
         Downloads a file from Google Cloud Storage.

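A hedged sketch of consuming the widened Union[str, bytes] return type: when no filename is given the hook is expected to hand back the object's raw bytes, so the caller decodes (or lets json.loads accept bytes) itself; the helper below is illustrative, not the hook's code.

import json
from typing import Union

def parse_schema(blob: Union[str, bytes]) -> list:
    text = blob.decode("utf-8") if isinstance(blob, bytes) else blob
    return json.loads(text)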
@@ -947,9 +947,7 @@ class BigQueryCreateEmptyTableOperator(BaseOperator):
                 delegate_to=self.delegate_to,
                 impersonation_chain=self.impersonation_chain,
             )
-            schema_fields = json.loads(
-                gcs_hook.download(gcs_bucket, gcs_object).decode("utf-8")  # type: ignore[attr-defined]
-            )  # type: ignore[attr-defined]
+            schema_fields = json.loads(gcs_hook.download(gcs_bucket, gcs_object))
         else:
             schema_fields = self.schema_fields

@@ -1189,8 +1187,7 @@ class BigQueryCreateExternalTableOperator(BaseOperator):
                 delegate_to=self.delegate_to,
                 impersonation_chain=self.impersonation_chain,
             )
-            schema_object = gcs_hook.download(self.bucket, self.schema_object)
-            schema_fields = json.loads(schema_object.decode("utf-8"))  # type: ignore[attr-defined]
+            schema_fields = json.loads(gcs_hook.download(self.bucket, self.schema_object))
         else:
             schema_fields = self.schema_fields
