[AIRFLOW-5424] Type annotations for GCP hooks
This commit is contained in:
Parent: 8e74ad7e64
Commit: f4a6586429
@@ -24,9 +24,11 @@ implementation for BigQuery.

import time
from copy import deepcopy
from typing import Any, NoReturn, Mapping, Union, Iterable, Dict, List, Optional, Tuple, Type

from googleapiclient.discovery import build
from googleapiclient.errors import HttpError
from pandas import DataFrame
from pandas_gbq.gbq import \
    _check_google_client_version as gbq_check_google_client_version
from pandas_gbq import read_gbq
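A note on the annotation style introduced here and used throughout the commit: function signatures receive inline annotations, while class and instance attributes use PEP 484 type comments, which keeps the modules parseable on Python versions that predate variable-annotation syntax. A minimal, hypothetical sketch of the two forms (ExampleHook is illustrative only, not part of Airflow):

    from typing import Dict, Optional

    class ExampleHook:
        # class attribute annotated with a PEP 484 type comment
        conn_name_attr = 'bigquery_conn_id'  # type: str

        def __init__(self, location: Optional[str] = None) -> None:
            # inline annotations on the signature, type comments on instance state
            self.location = location
            self.running_job_id = None  # type: Optional[str]
            self.api_resource_configs = {}  # type: Dict

    if __name__ == '__main__':
        hook = ExampleHook(location='EU')
        print(hook.conn_name_attr, hook.location)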
@ -45,19 +47,19 @@ class BigQueryHook(GoogleCloudBaseHook, DbApiHook):
|
|||
Interact with BigQuery. This hook uses the Google Cloud Platform
|
||||
connection.
|
||||
"""
|
||||
conn_name_attr = 'bigquery_conn_id'
|
||||
conn_name_attr = 'bigquery_conn_id' # type: str
|
||||
|
||||
def __init__(self,
|
||||
bigquery_conn_id='google_cloud_default',
|
||||
delegate_to=None,
|
||||
use_legacy_sql=True,
|
||||
location=None):
|
||||
bigquery_conn_id: str = 'google_cloud_default',
|
||||
delegate_to: Optional[str] = None,
|
||||
use_legacy_sql: bool = True,
|
||||
location: Optional[str] = None) -> None:
|
||||
super().__init__(
|
||||
gcp_conn_id=bigquery_conn_id, delegate_to=delegate_to)
|
||||
self.use_legacy_sql = use_legacy_sql
|
||||
self.location = location
|
||||
|
||||
def get_conn(self):
|
||||
def get_conn(self) -> "BigQueryConnection":
|
||||
"""
|
||||
Returns a BigQuery PEP 249 connection object.
|
||||
"""
|
||||
|
@ -70,7 +72,7 @@ class BigQueryHook(GoogleCloudBaseHook, DbApiHook):
|
|||
num_retries=self.num_retries
|
||||
)
|
||||
|
||||
def get_service(self):
|
||||
def get_service(self) -> Any:
|
||||
"""
|
||||
Returns a BigQuery service object.
|
||||
"""
|
||||
|
@ -78,7 +80,9 @@ class BigQueryHook(GoogleCloudBaseHook, DbApiHook):
|
|||
return build(
|
||||
'bigquery', 'v2', http=http_authorized, cache_discovery=False)
|
||||
|
||||
def insert_rows(self, table, rows, target_fields=None, commit_every=1000):
|
||||
def insert_rows(
|
||||
self, table: Any, rows: Any, target_fields: Any = None, commit_every: Any = 1000, replace: Any = False
|
||||
) -> NoReturn:
|
||||
"""
|
||||
Insertion is currently unsupported. Theoretically, you could use
|
||||
BigQuery's streaming API to insert rows into a table, but this hasn't
|
||||
|
@ -86,7 +90,9 @@ class BigQueryHook(GoogleCloudBaseHook, DbApiHook):
|
|||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_pandas_df(self, sql, parameters=None, dialect=None):
|
||||
def get_pandas_df(
|
||||
self, sql: str, parameters: Optional[Union[Iterable, Mapping]] = None, dialect: Optional[str] = None
|
||||
) -> DataFrame:
|
||||
"""
|
||||
Returns a Pandas DataFrame for the results produced by a BigQuery
|
||||
query. The DbApiHook method must be overridden because Pandas
|
||||
|
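For orientation, a hedged usage sketch of the hook with its new annotations; the import path varies between Airflow releases (airflow.contrib.hooks.bigquery_hook is assumed here) and the connection id is a placeholder that must already be configured:

    from airflow.contrib.hooks.bigquery_hook import BigQueryHook  # path may differ by version

    hook = BigQueryHook(bigquery_conn_id='google_cloud_default', use_legacy_sql=False)

    # get_pandas_df() now advertises -> DataFrame, so this is a pandas.DataFrame
    df = hook.get_pandas_df(sql='SELECT 1 AS x')
    print(df.head())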
@ -115,7 +121,7 @@ class BigQueryHook(GoogleCloudBaseHook, DbApiHook):
|
|||
verbose=False,
|
||||
credentials=credentials)
|
||||
|
||||
def table_exists(self, project_id, dataset_id, table_id):
|
||||
def table_exists(self, project_id: str, dataset_id: str, table_id: str) -> bool:
|
||||
"""
|
||||
Checks for the existence of a table in Google BigQuery.
|
||||
|
||||
|
@ -150,12 +156,9 @@ class BigQueryPandasConnector(GbqConnector):
|
|||
service account credentials into the binding.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
project_id,
|
||||
service,
|
||||
reauth=False,
|
||||
verbose=False,
|
||||
dialect='legacy'):
|
||||
def __init__(
|
||||
self, project_id: str, service: str, reauth: bool = False, verbose: bool = False, dialect="legacy"
|
||||
) -> None:
|
||||
super().__init__(project_id)
|
||||
gbq_check_google_client_version()
|
||||
gbq_test_google_api_imports()
|
||||
|
@ -173,21 +176,21 @@ class BigQueryConnection:
|
|||
work.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
self._args = args
|
||||
self._kwargs = kwargs
|
||||
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
""" BigQueryConnection does not have anything to close. """
|
||||
|
||||
def commit(self):
|
||||
def commit(self) -> None:
|
||||
""" BigQueryConnection does not support transactions. """
|
||||
|
||||
def cursor(self):
|
||||
def cursor(self) -> "BigQueryCursor":
|
||||
""" Return a new :py:class:`Cursor` object using the connection. """
|
||||
return BigQueryCursor(*self._args, **self._kwargs)
|
||||
|
||||
def rollback(self):
|
||||
def rollback(self) -> NoReturn:
|
||||
""" BigQueryConnection does not have transactions """
|
||||
raise NotImplementedError(
|
||||
"BigQueryConnection does not have transactions")
|
||||
|
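A sketch of the PEP 249 surface that BigQueryConnection and BigQueryCursor expose, matching the annotations above; the import path is again an assumption and the query is a placeholder:

    from airflow.contrib.hooks.bigquery_hook import BigQueryHook  # path may differ by version

    conn = BigQueryHook(use_legacy_sql=False).get_conn()  # -> "BigQueryConnection"
    cursor = conn.cursor()                                # -> "BigQueryCursor"
    cursor.execute('SELECT 17 AS answer')                 # runs a query job and stores its id
    print(cursor.fetchone())                              # -> List or None
    conn.commit()  # no-op: BigQueryConnection does not support transactions
    conn.close()   # no-op as well
    # conn.rollback() raises NotImplementedError, hence the NoReturn annotation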
@ -201,12 +204,12 @@ class BigQueryBaseCursor(LoggingMixin):
|
|||
"""
|
||||
|
||||
def __init__(self,
|
||||
service,
|
||||
project_id,
|
||||
use_legacy_sql=True,
|
||||
api_resource_configs=None,
|
||||
location=None,
|
||||
num_retries=5):
|
||||
service: Any,
|
||||
project_id: str,
|
||||
use_legacy_sql: bool = True,
|
||||
api_resource_configs: Optional[Dict] = None,
|
||||
location: Optional[str] = None,
|
||||
num_retries: int = 5) -> None:
|
||||
|
||||
self.service = service
|
||||
self.project_id = project_id
|
||||
|
@ -214,23 +217,23 @@ class BigQueryBaseCursor(LoggingMixin):
|
|||
if api_resource_configs:
|
||||
_validate_value("api_resource_configs", api_resource_configs, dict)
|
||||
self.api_resource_configs = api_resource_configs \
|
||||
if api_resource_configs else {}
|
||||
self.running_job_id = None
|
||||
if api_resource_configs else {}  # type: Dict
|
||||
self.running_job_id = None # type: Optional[str]
|
||||
self.location = location
|
||||
self.num_retries = num_retries
|
||||
|
||||
# pylint: disable=too-many-arguments
|
||||
def create_empty_table(self,
|
||||
project_id,
|
||||
dataset_id,
|
||||
table_id,
|
||||
schema_fields=None,
|
||||
time_partitioning=None,
|
||||
cluster_fields=None,
|
||||
labels=None,
|
||||
view=None,
|
||||
encryption_configuration=None,
|
||||
num_retries=5):
|
||||
project_id: str,
|
||||
dataset_id: str,
|
||||
table_id: str,
|
||||
schema_fields: Optional[List] = None,
|
||||
time_partitioning: Optional[Dict] = None,
|
||||
cluster_fields: Optional[List] = None,
|
||||
labels: Optional[Dict] = None,
|
||||
view: Optional[Dict] = None,
|
||||
encryption_configuration: Optional[Dict] = None,
|
||||
num_retries: int = 5) -> None:
|
||||
"""
|
||||
Creates a new, empty table in the dataset.
|
||||
To create a view, which is defined by a SQL query, parse a dictionary to 'view' kwarg
|
||||
|
@ -291,7 +294,7 @@ class BigQueryBaseCursor(LoggingMixin):
|
|||
'tableReference': {
|
||||
'tableId': table_id
|
||||
}
|
||||
}
|
||||
} # type: Dict[str, Any]
|
||||
|
||||
if schema_fields:
|
||||
table_resource['schema'] = {'fields': schema_fields}
|
||||
|
@ -333,23 +336,23 @@ class BigQueryBaseCursor(LoggingMixin):
|
|||
)
|
||||
|
||||
def create_external_table(self, # pylint: disable=too-many-locals,too-many-arguments
|
||||
external_project_dataset_table,
|
||||
schema_fields,
|
||||
source_uris,
|
||||
source_format='CSV',
|
||||
autodetect=False,
|
||||
compression='NONE',
|
||||
ignore_unknown_values=False,
|
||||
max_bad_records=0,
|
||||
skip_leading_rows=0,
|
||||
field_delimiter=',',
|
||||
quote_character=None,
|
||||
allow_quoted_newlines=False,
|
||||
allow_jagged_rows=False,
|
||||
src_fmt_configs=None,
|
||||
labels=None,
|
||||
encryption_configuration=None
|
||||
):
|
||||
external_project_dataset_table: str,
|
||||
schema_fields: List,
|
||||
source_uris: List,
|
||||
source_format: str = 'CSV',
|
||||
autodetect: bool = False,
|
||||
compression: str = 'NONE',
|
||||
ignore_unknown_values: bool = False,
|
||||
max_bad_records: int = 0,
|
||||
skip_leading_rows: int = 0,
|
||||
field_delimiter: str = ',',
|
||||
quote_character: Optional[str] = None,
|
||||
allow_quoted_newlines: bool = False,
|
||||
allow_jagged_rows: bool = False,
|
||||
src_fmt_configs: Optional[Dict] = None,
|
||||
labels: Optional[Dict] = None,
|
||||
encryption_configuration: Optional[Dict] = None
|
||||
) -> None:
|
||||
"""
|
||||
Creates a new external table in the dataset with the data in Google
|
||||
Cloud Storage. See here:
|
||||
|
@ -437,14 +440,14 @@ class BigQueryBaseCursor(LoggingMixin):
|
|||
allowed_formats = [
|
||||
"CSV", "NEWLINE_DELIMITED_JSON", "AVRO", "GOOGLE_SHEETS",
|
||||
"DATASTORE_BACKUP", "PARQUET"
|
||||
]
|
||||
] # type: List[str]
|
||||
if source_format not in allowed_formats:
|
||||
raise ValueError("{0} is not a valid source format. "
|
||||
"Please use one of the following types: {1}"
|
||||
.format(source_format, allowed_formats))
|
||||
|
||||
compression = compression.upper()
|
||||
allowed_compressions = ['NONE', 'GZIP']
|
||||
allowed_compressions = ['NONE', 'GZIP'] # type: List[str]
|
||||
if compression not in allowed_compressions:
|
||||
raise ValueError("{0} is not a valid compression format. "
|
||||
"Please use one of the following types: {1}"
|
||||
|
@ -463,7 +466,7 @@ class BigQueryBaseCursor(LoggingMixin):
|
|||
'datasetId': dataset_id,
|
||||
'tableId': external_table_id,
|
||||
}
|
||||
}
|
||||
} # type: Dict[str, Any]
|
||||
|
||||
if schema_fields:
|
||||
table_resource['externalDataConfiguration'].update({
|
||||
|
@ -532,19 +535,19 @@ class BigQueryBaseCursor(LoggingMixin):
|
|||
)
|
||||
|
||||
def patch_table(self, # pylint: disable=too-many-arguments
|
||||
dataset_id,
|
||||
table_id,
|
||||
project_id=None,
|
||||
description=None,
|
||||
expiration_time=None,
|
||||
external_data_configuration=None,
|
||||
friendly_name=None,
|
||||
labels=None,
|
||||
schema=None,
|
||||
time_partitioning=None,
|
||||
view=None,
|
||||
require_partition_filter=None,
|
||||
encryption_configuration=None):
|
||||
dataset_id: str,
|
||||
table_id: str,
|
||||
project_id: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
expiration_time: Optional[int] = None,
|
||||
external_data_configuration: Optional[Dict] = None,
|
||||
friendly_name: Optional[str] = None,
|
||||
labels: Optional[Dict] = None,
|
||||
schema: Optional[List] = None,
|
||||
time_partitioning: Optional[Dict] = None,
|
||||
view: Optional[Dict] = None,
|
||||
require_partition_filter: Optional[bool] = None,
|
||||
encryption_configuration: Optional[Dict] = None) -> None:
|
||||
"""
|
||||
Patch information in an existing table.
|
||||
It only updates fields that are provided in the request object.
|
||||
|
@ -608,7 +611,7 @@ class BigQueryBaseCursor(LoggingMixin):
|
|||
|
||||
project_id = project_id if project_id is not None else self.project_id
|
||||
|
||||
table_resource = {}
|
||||
table_resource = {} # type: Dict[str, Any]
|
||||
|
||||
if description is not None:
|
||||
table_resource['description'] = description
|
||||
|
@ -651,25 +654,25 @@ class BigQueryBaseCursor(LoggingMixin):
|
|||
|
||||
# pylint: disable=too-many-locals,too-many-arguments, too-many-branches
|
||||
def run_query(self,
|
||||
sql,
|
||||
destination_dataset_table=None,
|
||||
write_disposition='WRITE_EMPTY',
|
||||
allow_large_results=False,
|
||||
flatten_results=None,
|
||||
udf_config=None,
|
||||
use_legacy_sql=None,
|
||||
maximum_billing_tier=None,
|
||||
maximum_bytes_billed=None,
|
||||
create_disposition='CREATE_IF_NEEDED',
|
||||
query_params=None,
|
||||
labels=None,
|
||||
schema_update_options=None,
|
||||
priority='INTERACTIVE',
|
||||
time_partitioning=None,
|
||||
api_resource_configs=None,
|
||||
cluster_fields=None,
|
||||
location=None,
|
||||
encryption_configuration=None):
|
||||
sql: str,
|
||||
destination_dataset_table: Optional[str] = None,
|
||||
write_disposition: str = 'WRITE_EMPTY',
|
||||
allow_large_results: bool = False,
|
||||
flatten_results: Optional[bool] = None,
|
||||
udf_config: Optional[List] = None,
|
||||
use_legacy_sql: Optional[bool] = None,
|
||||
maximum_billing_tier: Optional[int] = None,
|
||||
maximum_bytes_billed: Optional[float] = None,
|
||||
create_disposition: str = 'CREATE_IF_NEEDED',
|
||||
query_params: Optional[List] = None,
|
||||
labels: Optional[Dict] = None,
|
||||
schema_update_options: Optional[Iterable] = None,
|
||||
priority: str = 'INTERACTIVE',
|
||||
time_partitioning: Optional[Dict] = None,
|
||||
api_resource_configs: Optional[Dict] = None,
|
||||
cluster_fields: Optional[List[str]] = None,
|
||||
location: Optional[str] = None,
|
||||
encryption_configuration: Optional[Dict] = None) -> str:
|
||||
"""
|
||||
Executes a BigQuery SQL query. Optionally persists results in a BigQuery
|
||||
table. See here:
|
||||
|
@ -802,14 +805,14 @@ class BigQueryBaseCursor(LoggingMixin):
|
|||
_split_tablename(table_input=destination_dataset_table,
|
||||
default_project_id=self.project_id)
|
||||
|
||||
destination_dataset_table = {
|
||||
destination_dataset_table = { # type: ignore
|
||||
'projectId': destination_project,
|
||||
'datasetId': destination_dataset,
|
||||
'tableId': destination_table,
|
||||
}
|
||||
|
||||
if cluster_fields:
|
||||
cluster_fields = {'fields': cluster_fields}
|
||||
cluster_fields = {'fields': cluster_fields} # type: ignore
|
||||
|
||||
query_param_list = [
|
||||
(sql, 'query', None, (str,)),
|
||||
|
@ -823,7 +826,7 @@ class BigQueryBaseCursor(LoggingMixin):
|
|||
(schema_update_options, 'schemaUpdateOptions', None, list),
|
||||
(destination_dataset_table, 'destinationTable', None, dict),
|
||||
(cluster_fields, 'clustering', None, dict),
|
||||
]
|
||||
] # type: List[Tuple]
|
||||
|
||||
for param, param_name, param_default, param_type in query_param_list:
|
||||
if param_name not in configuration['query'] and param in [None, {}, ()]:
|
||||
|
@ -888,13 +891,13 @@ class BigQueryBaseCursor(LoggingMixin):
|
|||
|
||||
def run_extract( # noqa
|
||||
self,
|
||||
source_project_dataset_table,
|
||||
destination_cloud_storage_uris,
|
||||
compression='NONE',
|
||||
export_format='CSV',
|
||||
field_delimiter=',',
|
||||
print_header=True,
|
||||
labels=None):
|
||||
source_project_dataset_table: str,
|
||||
destination_cloud_storage_uris: str,
|
||||
compression: str = 'NONE',
|
||||
export_format: str = 'CSV',
|
||||
field_delimiter: str = ',',
|
||||
print_header: bool = True,
|
||||
labels: Optional[Dict] = None) -> str:
|
||||
"""
|
||||
Executes a BigQuery extract command to copy data from BigQuery to
|
||||
Google Cloud Storage. See here:
|
||||
|
@ -940,7 +943,7 @@ class BigQueryBaseCursor(LoggingMixin):
|
|||
'destinationUris': destination_cloud_storage_uris,
|
||||
'destinationFormat': export_format,
|
||||
}
|
||||
}
|
||||
} # type: Dict[str, Any]
|
||||
|
||||
if labels:
|
||||
configuration['labels'] = labels
|
||||
|
@ -955,12 +958,12 @@ class BigQueryBaseCursor(LoggingMixin):
|
|||
return self.run_with_configuration(configuration)
|
||||
|
||||
def run_copy(self, # pylint: disable=invalid-name
|
||||
source_project_dataset_tables,
|
||||
destination_project_dataset_table,
|
||||
write_disposition='WRITE_EMPTY',
|
||||
create_disposition='CREATE_IF_NEEDED',
|
||||
labels=None,
|
||||
encryption_configuration=None):
|
||||
source_project_dataset_tables: Union[List, str],
|
||||
destination_project_dataset_table: str,
|
||||
write_disposition: str = 'WRITE_EMPTY',
|
||||
create_disposition: str = 'CREATE_IF_NEEDED',
|
||||
labels: Optional[Dict] = None,
|
||||
encryption_configuration: Optional[Dict] = None) -> str:
|
||||
"""
|
||||
Executes a BigQuery copy command to copy data from one BigQuery table
|
||||
to another. See here:
|
||||
|
@ -1041,25 +1044,25 @@ class BigQueryBaseCursor(LoggingMixin):
|
|||
return self.run_with_configuration(configuration)
|
||||
|
||||
def run_load(self, # pylint: disable=too-many-locals,too-many-arguments,invalid-name
|
||||
destination_project_dataset_table,
|
||||
source_uris,
|
||||
schema_fields=None,
|
||||
source_format='CSV',
|
||||
create_disposition='CREATE_IF_NEEDED',
|
||||
skip_leading_rows=0,
|
||||
write_disposition='WRITE_EMPTY',
|
||||
field_delimiter=',',
|
||||
max_bad_records=0,
|
||||
quote_character=None,
|
||||
ignore_unknown_values=False,
|
||||
allow_quoted_newlines=False,
|
||||
allow_jagged_rows=False,
|
||||
schema_update_options=None,
|
||||
src_fmt_configs=None,
|
||||
time_partitioning=None,
|
||||
cluster_fields=None,
|
||||
autodetect=False,
|
||||
encryption_configuration=None):
|
||||
destination_project_dataset_table: str,
|
||||
source_uris: List,
|
||||
schema_fields: Optional[List] = None,
|
||||
source_format: str = 'CSV',
|
||||
create_disposition: str = 'CREATE_IF_NEEDED',
|
||||
skip_leading_rows: int = 0,
|
||||
write_disposition: str = 'WRITE_EMPTY',
|
||||
field_delimiter: str = ',',
|
||||
max_bad_records: int = 0,
|
||||
quote_character: Optional[str] = None,
|
||||
ignore_unknown_values: bool = False,
|
||||
allow_quoted_newlines: bool = False,
|
||||
allow_jagged_rows: bool = False,
|
||||
schema_update_options: Optional[Iterable] = None,
|
||||
src_fmt_configs: Optional[Dict] = None,
|
||||
time_partitioning: Optional[Dict] = None,
|
||||
cluster_fields: Optional[List] = None,
|
||||
autodetect: bool = False,
|
||||
encryption_configuration: Optional[Dict] = None) -> str:
|
||||
"""
|
||||
Executes a BigQuery load command to load data from Google Cloud Storage
|
||||
to BigQuery. See here:
|
||||
|
@ -1234,6 +1237,19 @@ class BigQueryBaseCursor(LoggingMixin):
|
|||
"destinationEncryptionConfiguration"
|
||||
] = encryption_configuration
|
||||
|
||||
# if following fields are not specified in src_fmt_configs,
|
||||
# honor the top-level params for backward-compatibility
|
||||
if 'skipLeadingRows' not in src_fmt_configs:
|
||||
src_fmt_configs['skipLeadingRows'] = skip_leading_rows
|
||||
if 'fieldDelimiter' not in src_fmt_configs:
|
||||
src_fmt_configs['fieldDelimiter'] = field_delimiter
|
||||
if 'ignoreUnknownValues' not in src_fmt_configs:
|
||||
src_fmt_configs['ignoreUnknownValues'] = ignore_unknown_values
|
||||
if quote_character is not None:
|
||||
src_fmt_configs['quote'] = quote_character
|
||||
if allow_quoted_newlines:
|
||||
src_fmt_configs['allowQuotedNewlines'] = allow_quoted_newlines
|
||||
|
||||
src_fmt_to_configs_mapping = {
|
||||
'CSV': [
|
||||
'allowJaggedRows', 'allowQuotedNewlines', 'autodetect',
|
||||
|
@ -1266,7 +1282,7 @@ class BigQueryBaseCursor(LoggingMixin):
|
|||
|
||||
return self.run_with_configuration(configuration)
|
||||
|
||||
def run_with_configuration(self, configuration):
|
||||
def run_with_configuration(self, configuration: Dict) -> str:
|
||||
"""
|
||||
Executes a BigQuery SQL query. See here:
|
||||
|
||||
|
@ -1279,8 +1295,8 @@ class BigQueryBaseCursor(LoggingMixin):
|
|||
https://cloud.google.com/bigquery/docs/reference/v2/jobs for
|
||||
details.
|
||||
"""
|
||||
jobs = self.service.jobs()
|
||||
job_data = {'configuration': configuration}
|
||||
jobs = self.service.jobs() # type: Any
|
||||
job_data = {'configuration': configuration} # type: Dict[str, Dict]
|
||||
|
||||
# Send query and wait for reply.
|
||||
query_reply = jobs \
|
||||
|
@ -1293,7 +1309,7 @@ class BigQueryBaseCursor(LoggingMixin):
|
|||
location = self.location
|
||||
|
||||
# Wait for query to finish.
|
||||
keep_polling_job = True
|
||||
keep_polling_job = True # type: bool
|
||||
while keep_polling_job:
|
||||
try:
|
||||
keep_polling_job = self._check_query_status(jobs, keep_polling_job, location)
|
||||
|
@ -1309,9 +1325,9 @@ class BigQueryBaseCursor(LoggingMixin):
|
|||
'BigQuery job status check failed. Final error was: {}'.
|
||||
format(err.resp.status))
|
||||
|
||||
return self.running_job_id
|
||||
return self.running_job_id # type: ignore
|
||||
|
||||
def _check_query_status(self, jobs, keep_polling_job, location):
|
||||
def _check_query_status(self, jobs: Any, keep_polling_job: bool, location: str) -> bool:
|
||||
if location:
|
||||
job = jobs.get(
|
||||
projectId=self.project_id,
|
||||
|
@ -1335,7 +1351,7 @@ class BigQueryBaseCursor(LoggingMixin):
|
|||
time.sleep(5)
|
||||
return keep_polling_job
|
||||
|
||||
def poll_job_complete(self, job_id):
|
||||
def poll_job_complete(self, job_id: str) -> bool:
|
||||
"""
|
||||
Check if jobs completed.
|
||||
|
||||
|
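A hypothetical polling helper built on poll_job_complete(); the interval and structure are illustrative only and are not how run_with_configuration() is implemented:

    import time

    def wait_for_job(cursor, job_id: str, poll_interval: int = 5) -> None:
        # poll_job_complete() returns True once BigQuery reports the job as DONE
        while not cursor.poll_job_complete(job_id):
            time.sleep(poll_interval)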
@ -1365,7 +1381,7 @@ class BigQueryBaseCursor(LoggingMixin):
|
|||
format(err.resp.status))
|
||||
return False
|
||||
|
||||
def cancel_query(self):
|
||||
def cancel_query(self) -> None:
|
||||
"""
|
||||
Cancel all started queries that have not yet completed
|
||||
"""
|
||||
|
@ -1408,7 +1424,7 @@ class BigQueryBaseCursor(LoggingMixin):
|
|||
self.running_job_id)
|
||||
time.sleep(5)
|
||||
|
||||
def get_schema(self, dataset_id, table_id):
|
||||
def get_schema(self, dataset_id: str, table_id: str) -> Dict:
|
||||
"""
|
||||
Get the schema for a given dataset.table.
|
||||
see https://cloud.google.com/bigquery/docs/reference/v2/tables#resource
|
||||
|
@ -1422,9 +1438,9 @@ class BigQueryBaseCursor(LoggingMixin):
|
|||
.execute(num_retries=self.num_retries)
|
||||
return tables_resource['schema']
|
||||
|
||||
def get_tabledata(self, dataset_id, table_id,
|
||||
max_results=None, selected_fields=None, page_token=None,
|
||||
start_index=None):
|
||||
def get_tabledata(self, dataset_id: str, table_id: str,
|
||||
max_results: Optional[int] = None, selected_fields: Optional[str] = None,
|
||||
page_token: Optional[str] = None, start_index: Optional[int] = None) -> Dict:
|
||||
"""
|
||||
Get the data of a given dataset.table and optionally with selected columns.
|
||||
see https://cloud.google.com/bigquery/docs/reference/v2/tabledata/list
|
||||
|
@ -1439,7 +1455,7 @@ class BigQueryBaseCursor(LoggingMixin):
|
|||
:param start_index: zero based index of the starting row to read.
|
||||
:return: map containing the requested rows.
|
||||
"""
|
||||
optional_params = {}
|
||||
optional_params = {} # type: Dict[str, Any]
|
||||
if max_results:
|
||||
optional_params['maxResults'] = max_results
|
||||
if selected_fields:
|
||||
|
@ -1454,8 +1470,8 @@ class BigQueryBaseCursor(LoggingMixin):
|
|||
tableId=table_id,
|
||||
**optional_params).execute(num_retries=self.num_retries))
|
||||
|
||||
def run_table_delete(self, deletion_dataset_table,
|
||||
ignore_if_missing=False):
|
||||
def run_table_delete(self, deletion_dataset_table: str,
|
||||
ignore_if_missing: bool = False) -> None:
|
||||
"""
|
||||
Delete an existing table from the dataset;
|
||||
If the table does not exist, return an error unless ignore_if_missing
|
||||
|
@ -1488,7 +1504,8 @@ class BigQueryBaseCursor(LoggingMixin):
|
|||
else:
|
||||
self.log.info('Table does not exist. Skipping.')
|
||||
|
||||
def run_table_upsert(self, dataset_id, table_resource, project_id=None):
|
||||
def run_table_upsert(self, dataset_id: str, table_resource: Dict,
|
||||
project_id: Optional[str] = None) -> Dict:
|
||||
"""
|
||||
Creates a new, empty table in the dataset;
|
||||
If the table already exists, update the existing table.
|
||||
|
@ -1538,11 +1555,11 @@ class BigQueryBaseCursor(LoggingMixin):
|
|||
body=table_resource).execute(num_retries=self.num_retries)
|
||||
|
||||
def run_grant_dataset_view_access(self,
|
||||
source_dataset,
|
||||
view_dataset,
|
||||
view_table,
|
||||
source_project=None,
|
||||
view_project=None):
|
||||
source_dataset: str,
|
||||
view_dataset: str,
|
||||
view_table: str,
|
||||
source_project: Optional[str] = None,
|
||||
view_project: Optional[str] = None) -> Dict:
|
||||
"""
|
||||
Grant authorized view access of a dataset to a view table.
|
||||
If this view has already been granted access to the dataset, do nothing.
|
||||
|
@ -1600,8 +1617,11 @@ class BigQueryBaseCursor(LoggingMixin):
|
|||
view_project, view_dataset, view_table, source_project, source_dataset)
|
||||
return source_dataset_resource
|
||||
|
||||
def create_empty_dataset(self, dataset_id="", project_id="",
|
||||
location=None, dataset_reference=None):
|
||||
def create_empty_dataset(self,
|
||||
dataset_id: str = "",
|
||||
project_id: str = "",
|
||||
location: Optional[str] = None,
|
||||
dataset_reference: Optional[Dict] = None) -> None:
|
||||
"""
|
||||
Create a new empty dataset:
|
||||
https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets/insert
|
||||
|
@ -1659,9 +1679,8 @@ class BigQueryBaseCursor(LoggingMixin):
|
|||
'location', location,
|
||||
dataset_reference, 'dataset_reference')
|
||||
|
||||
dataset_id = dataset_reference.get("datasetReference").get("datasetId")
|
||||
dataset_project_id = dataset_reference.get("datasetReference").get(
|
||||
"projectId")
|
||||
dataset_id = dataset_reference.get("datasetReference").get("datasetId") # type: ignore
|
||||
dataset_project_id = dataset_reference.get("datasetReference").get("projectId") # type: ignore
|
||||
|
||||
self.log.info('Creating Dataset: %s in project: %s ', dataset_id,
|
||||
dataset_project_id)
|
||||
|
@ -1678,7 +1697,7 @@ class BigQueryBaseCursor(LoggingMixin):
|
|||
'BigQuery job failed. Error was: {}'.format(err.content)
|
||||
)
|
||||
|
||||
def delete_dataset(self, project_id, dataset_id, delete_contents=False):
|
||||
def delete_dataset(self, project_id: str, dataset_id: str, delete_contents: bool = False) -> None:
|
||||
"""
|
||||
Delete a BigQuery dataset in your project.
|
||||
|
||||
|
@ -1710,7 +1729,7 @@ class BigQueryBaseCursor(LoggingMixin):
|
|||
'BigQuery job failed. Error was: {}'.format(err.content)
|
||||
)
|
||||
|
||||
def get_dataset(self, dataset_id, project_id=None):
|
||||
def get_dataset(self, dataset_id: str, project_id: Optional[str] = None) -> Dict:
|
||||
"""
|
||||
Method returns dataset_resource if dataset exists
and raises 404 error if dataset does not exist
|
||||
|
@ -1742,7 +1761,7 @@ class BigQueryBaseCursor(LoggingMixin):
|
|||
|
||||
return dataset_resource
|
||||
|
||||
def get_datasets_list(self, project_id=None):
|
||||
def get_datasets_list(self, project_id: Optional[str] = None) -> List:
|
||||
"""
|
||||
Method returns full list of BigQuery datasets in the current project
|
||||
|
||||
|
@ -1790,7 +1809,7 @@ class BigQueryBaseCursor(LoggingMixin):
|
|||
|
||||
return datasets_list
|
||||
|
||||
def patch_dataset(self, dataset_id, dataset_resource, project_id=None):
|
||||
def patch_dataset(self, dataset_id: str, dataset_resource: str, project_id: Optional[str] = None) -> Dict:
|
||||
"""
|
||||
Patches information in an existing dataset.
|
||||
It only replaces fields that are provided in the submitted dataset resource.
|
||||
|
@ -1835,7 +1854,8 @@ class BigQueryBaseCursor(LoggingMixin):
|
|||
|
||||
return dataset
|
||||
|
||||
def update_dataset(self, dataset_id, dataset_resource, project_id=None):
|
||||
def update_dataset(self, dataset_id: str,
|
||||
dataset_resource: Dict, project_id: Optional[str] = None) -> Dict:
|
||||
"""
|
||||
Updates information in an existing dataset. The update method replaces the entire
|
||||
dataset resource, whereas the patch method only replaces fields that are provided
|
||||
|
@ -1881,9 +1901,9 @@ class BigQueryBaseCursor(LoggingMixin):
|
|||
|
||||
return dataset
|
||||
|
||||
def insert_all(self, project_id, dataset_id, table_id,
|
||||
rows, ignore_unknown_values=False,
|
||||
skip_invalid_rows=False, fail_on_error=False):
|
||||
def insert_all(self, project_id: str, dataset_id: str, table_id: str,
|
||||
rows: List, ignore_unknown_values: bool = False,
|
||||
skip_invalid_rows: bool = False, fail_on_error: bool = False) -> None:
|
||||
"""
|
||||
Method to stream data into BigQuery one record at a time without needing
|
||||
to run a load job
|
||||
|
@ -1967,7 +1987,14 @@ class BigQueryCursor(BigQueryBaseCursor):
|
|||
https://github.com/dropbox/PyHive/blob/master/pyhive/common.py
|
||||
"""
|
||||
|
||||
def __init__(self, service, project_id, use_legacy_sql=True, location=None, num_retries=5):
|
||||
def __init__(
|
||||
self,
|
||||
service: Any,
|
||||
project_id: str,
|
||||
use_legacy_sql: bool = True,
|
||||
location: Optional[str] = None,
|
||||
num_retries: int = 5,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
service=service,
|
||||
project_id=project_id,
|
||||
|
@ -1975,26 +2002,26 @@ class BigQueryCursor(BigQueryBaseCursor):
|
|||
location=location,
|
||||
num_retries=num_retries
|
||||
)
|
||||
self.buffersize = None
|
||||
self.page_token = None
|
||||
self.job_id = None
|
||||
self.buffer = []
|
||||
self.all_pages_loaded = False
|
||||
self.buffersize = None # type: Optional[int]
|
||||
self.page_token = None # type: Optional[str]
|
||||
self.job_id = None # type: Optional[str]
|
||||
self.buffer = [] # type: list
|
||||
self.all_pages_loaded = False # type: bool
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
def description(self) -> NoReturn:
|
||||
""" The schema description method is not currently implemented. """
|
||||
raise NotImplementedError
|
||||
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
""" By default, do nothing """
|
||||
|
||||
@property
|
||||
def rowcount(self):
|
||||
def rowcount(self) -> int:
|
||||
""" By default, return -1 to indicate that this is not supported. """
|
||||
return -1
|
||||
|
||||
def execute(self, operation, parameters=None):
|
||||
def execute(self, operation: str, parameters: Optional[Dict] = None) -> None:
|
||||
"""
|
||||
Executes a BigQuery query, and returns the job ID.
|
||||
|
||||
|
@ -2007,7 +2034,7 @@ class BigQueryCursor(BigQueryBaseCursor):
|
|||
parameters) if parameters else operation
|
||||
self.job_id = self.run_query(sql)
|
||||
|
||||
def executemany(self, operation, seq_of_parameters):
|
||||
def executemany(self, operation: str, seq_of_parameters: List) -> None:
|
||||
"""
|
||||
Execute a BigQuery query multiple times with different parameters.
|
||||
|
||||
|
@ -2020,11 +2047,11 @@ class BigQueryCursor(BigQueryBaseCursor):
|
|||
for parameters in seq_of_parameters:
|
||||
self.execute(operation, parameters)
|
||||
|
||||
def fetchone(self):
|
||||
def fetchone(self) -> Union[List, None]:
|
||||
""" Fetch the next row of a query result set. """
|
||||
return self.next()
|
||||
|
||||
def next(self):
|
||||
def next(self) -> Union[List, None]:
|
||||
"""
|
||||
Helper method for fetchone, which returns the next row from a buffer.
|
||||
If the buffer is empty, attempts to paginate through the result set for
|
||||
|
@ -2067,7 +2094,7 @@ class BigQueryCursor(BigQueryBaseCursor):
|
|||
|
||||
return self.buffer.pop(0)
|
||||
|
||||
def fetchmany(self, size=None):
|
||||
def fetchmany(self, size: Optional[int] = None) -> List:
|
||||
"""
|
||||
Fetch the next set of rows of a query result, returning a sequence of sequences
|
||||
(e.g. a list of tuples). An empty sequence is returned when no more rows are
|
||||
|
@ -2090,7 +2117,7 @@ class BigQueryCursor(BigQueryBaseCursor):
|
|||
result.append(one)
|
||||
return result
|
||||
|
||||
def fetchall(self):
|
||||
def fetchall(self) -> List[List]:
|
||||
"""
|
||||
Fetch all (remaining) rows of a query result, returning them as a sequence of
|
||||
sequences (e.g. a list of tuples).
|
||||
|
@ -2104,27 +2131,27 @@ class BigQueryCursor(BigQueryBaseCursor):
|
|||
result.append(one)
|
||||
return result
|
||||
|
||||
def get_arraysize(self):
|
||||
def get_arraysize(self) -> int:
|
||||
""" Specifies the number of rows to fetch at a time with .fetchmany() """
|
||||
return self.buffersize or 1
|
||||
|
||||
def set_arraysize(self, arraysize):
|
||||
def set_arraysize(self, arraysize: int) -> None:
|
||||
""" Specifies the number of rows to fetch at a time with .fetchmany() """
|
||||
self.buffersize = arraysize
|
||||
|
||||
arraysize = property(get_arraysize, set_arraysize)
|
||||
|
||||
def setinputsizes(self, sizes):
|
||||
def setinputsizes(self, sizes: Any) -> None:
|
||||
""" Does nothing by default """
|
||||
|
||||
def setoutputsize(self, size, column=None):
|
||||
def setoutputsize(self, size: Any, column: Any = None) -> None:
|
||||
""" Does nothing by default """
|
||||
|
||||
|
||||
def _bind_parameters(operation, parameters):
|
||||
def _bind_parameters(operation: str, parameters: Dict) -> str:
|
||||
""" Helper method that binds parameters to a SQL query. """
|
||||
# inspired by MySQL Python Connector (conversion.py)
|
||||
string_parameters = {}
|
||||
string_parameters = {}  # type: Dict[str, str]
|
||||
for (name, value) in parameters.items():
|
||||
if value is None:
|
||||
string_parameters[name] = 'NULL'
|
||||
|
@ -2135,7 +2162,7 @@ def _bind_parameters(operation, parameters):
|
|||
return operation % string_parameters
|
||||
|
||||
|
||||
def _escape(s):
|
||||
def _escape(s: str) -> str:
|
||||
""" Helper method that escapes parameters to a SQL query. """
|
||||
e = s
|
||||
e = e.replace('\\', '\\\\')
|
||||
|
@ -2146,7 +2173,7 @@ def _escape(s):
|
|||
return e
|
||||
|
||||
|
||||
def _bq_cast(string_field, bq_type):
|
||||
def _bq_cast(string_field: str, bq_type: str) -> Union[None, int, float, bool, str]:
|
||||
"""
|
||||
Helper method that casts a BigQuery row to the appropriate data types.
|
||||
This is useful because BigQuery returns all fields as strings.
|
||||
|
@ -2166,7 +2193,8 @@ def _bq_cast(string_field, bq_type):
|
|||
return string_field
|
||||
|
||||
|
||||
def _split_tablename(table_input, default_project_id, var_name=None):
|
||||
def _split_tablename(table_input: str, default_project_id: str,
|
||||
var_name: Optional[str] = None) -> Tuple[str, str, str]:
|
||||
|
||||
if '.' not in table_input:
|
||||
raise ValueError(
|
||||
|
@ -2231,8 +2259,9 @@ def _split_tablename(table_input, default_project_id, var_name=None):
|
|||
return project_id, dataset_id, table_id
|
||||
|
||||
|
||||
def _cleanse_time_partitioning(destination_dataset_table, time_partitioning_in):
|
||||
# if it is a partitioned table ($ is in the table name) add partition load option
|
||||
def _cleanse_time_partitioning(
|
||||
destination_dataset_table: Optional[str], time_partitioning_in: Optional[Dict]
|
||||
) -> Dict: # if it is a partitioned table ($ is in the table name) add partition load option
|
||||
|
||||
if time_partitioning_in is None:
|
||||
time_partitioning_in = {}
|
||||
|
@ -2244,7 +2273,7 @@ def _cleanse_time_partitioning(destination_dataset_table, time_partitioning_in):
|
|||
return time_partitioning_out
|
||||
|
||||
|
||||
def _validate_value(key, value, expected_type):
|
||||
def _validate_value(key: Any, value: Any, expected_type: Type) -> None:
|
||||
""" function to check expected type and raise
|
||||
error if the type is not correct """
|
||||
if not isinstance(value, expected_type):
|
||||
|
@ -2252,8 +2281,8 @@ def _validate_value(key, value, expected_type):
|
|||
key, expected_type, type(value)))
|
||||
|
||||
|
||||
def _api_resource_configs_duplication_check(key, value, config_dict,
|
||||
config_dict_name='api_resource_configs'):
|
||||
def _api_resource_configs_duplication_check(key: Any, value: Any, config_dict: Dict,
|
||||
config_dict_name='api_resource_configs') -> None:
|
||||
if key in config_dict and value != config_dict[key]:
|
||||
raise ValueError("Values of {param_name} param are duplicated. "
|
||||
"{dict_name} contained {param_name} param "
|
||||
|
@ -2262,8 +2291,10 @@ def _api_resource_configs_duplication_check(key, value, config_dict,
|
|||
.format(param_name=key, dict_name=config_dict_name))
|
||||
|
||||
|
||||
def _validate_src_fmt_configs(source_format, src_fmt_configs, valid_configs,
|
||||
backward_compatibility_configs=None):
|
||||
def _validate_src_fmt_configs(source_format: str,
|
||||
src_fmt_configs: Dict,
|
||||
valid_configs: List[str],
|
||||
backward_compatibility_configs: Optional[Dict] = None) -> Dict:
|
||||
"""
|
||||
Validates the given src_fmt_configs against a valid configuration for the source format.
|
||||
Adds the backward compatibility config to the src_fmt_configs.
|
||||
|
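For reference, a standalone sketch of the contract this validator enforces (merge backward-compatibility defaults, then reject unknown keys); it is an illustration of the idea, not the module's implementation:

    from typing import Dict, List, Optional

    def validate_src_fmt_configs(source_format: str,
                                 src_fmt_configs: Dict,
                                 valid_configs: List[str],
                                 backward_compatibility_configs: Optional[Dict] = None) -> Dict:
        # fold in backward-compatibility defaults without overriding explicit keys
        for key, value in (backward_compatibility_configs or {}).items():
            if key not in src_fmt_configs and key in valid_configs:
                src_fmt_configs[key] = value
        # reject options that are not valid for this source format
        for key in src_fmt_configs:
            if key not in valid_configs:
                raise ValueError("{} is not a valid src_fmt_configs for type {}."
                                 .format(key, source_format))
        return src_fmt_configs

    print(validate_src_fmt_configs('CSV', {'quote': '"'}, ['quote', 'fieldDelimiter']))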
|
|
@ -25,7 +25,7 @@ import time
|
|||
import warnings
|
||||
from copy import deepcopy
|
||||
from datetime import timedelta
|
||||
from typing import Dict, List, Tuple, Union, Optional
|
||||
from typing import Dict, List, Union, Set, Optional
|
||||
|
||||
from googleapiclient.discovery import build
|
||||
|
||||
|
@ -111,7 +111,7 @@ class GCPTransferServiceHook(GoogleCloudBaseHook):
|
|||
self,
|
||||
api_version: str = 'v1',
|
||||
gcp_conn_id: str = 'google_cloud_default',
|
||||
delegate_to: str = None
|
||||
delegate_to: Optional[str] = None
|
||||
) -> None:
|
||||
super().__init__(gcp_conn_id, delegate_to)
|
||||
self.api_version = api_version
|
||||
|
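Several changes in this file (and in the hooks below) replace "param: str = None" with "param: Optional[str] = None". A minimal sketch of why, assuming mypy is run with implicit Optional disabled (--no-implicit-optional):

    from typing import Optional

    def bad(delegate_to: str = None) -> None:        # flagged: default None is not a str
        pass

    def good(delegate_to: Optional[str] = None) -> None:  # accepted
        pass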
@ -150,7 +150,7 @@ class GCPTransferServiceHook(GoogleCloudBaseHook):
|
|||
|
||||
@GoogleCloudBaseHook.fallback_to_default_project_id
|
||||
@GoogleCloudBaseHook.catch_http_exception
|
||||
def get_transfer_job(self, job_name: str, project_id: str = None) -> Dict:
|
||||
def get_transfer_job(self, job_name: str, project_id: Optional[str] = None) -> Dict:
|
||||
"""
|
||||
Gets the latest state of a long-running operation in Google Storage
|
||||
Transfer Service.
|
||||
|
@ -172,7 +172,7 @@ class GCPTransferServiceHook(GoogleCloudBaseHook):
|
|||
.execute(num_retries=self.num_retries)
|
||||
)
|
||||
|
||||
def list_transfer_job(self, request_filter: Dict = None, **kwargs) -> List[Dict]:
|
||||
def list_transfer_job(self, request_filter: Optional[Dict] = None, **kwargs) -> List[Dict]:
|
||||
"""
|
||||
Lists long-running operations in Google Storage Transfer
|
||||
Service that match the specified filter.
|
||||
|
@ -230,7 +230,7 @@ class GCPTransferServiceHook(GoogleCloudBaseHook):
|
|||
|
||||
@GoogleCloudBaseHook.fallback_to_default_project_id
|
||||
@GoogleCloudBaseHook.catch_http_exception
|
||||
def delete_transfer_job(self, job_name: str, project_id: str = None) -> None:
|
||||
def delete_transfer_job(self, job_name: str, project_id: Optional[str] = None) -> None:
|
||||
"""
|
||||
Deletes a transfer job. This is a soft delete. After a transfer job is
|
||||
deleted, the job and all the transfer executions are subject to garbage
|
||||
|
@ -293,7 +293,7 @@ class GCPTransferServiceHook(GoogleCloudBaseHook):
|
|||
)
|
||||
|
||||
@GoogleCloudBaseHook.catch_http_exception
|
||||
def list_transfer_operations(self, request_filter: Dict = None, **kwargs) -> List[Dict]:
|
||||
def list_transfer_operations(self, request_filter: Optional[Dict] = None, **kwargs) -> List[Dict]:
|
||||
"""
|
||||
Gets a transfer operation in Google Storage Transfer Service.
|
||||
|
||||
|
@ -369,7 +369,7 @@ class GCPTransferServiceHook(GoogleCloudBaseHook):
|
|||
def wait_for_transfer_job(
|
||||
self,
|
||||
job: Dict,
|
||||
expected_statuses: Tuple[str] = (GcpTransferOperationStatus.SUCCESS,),
|
||||
expected_statuses: Optional[Set[str]] = None,
|
||||
timeout: Optional[Union[float, timedelta]] = None
|
||||
) -> None:
|
||||
"""
|
||||
|
@ -388,6 +388,9 @@ class GCPTransferServiceHook(GoogleCloudBaseHook):
|
|||
:type timeout: Optional[Union[float, timedelta]]
|
||||
:rtype: None
|
||||
"""
|
||||
expected_statuses = (
|
||||
{GcpTransferOperationStatus.SUCCESS} if not expected_statuses else expected_statuses
|
||||
)
|
||||
if timeout is None:
|
||||
timeout = 60
|
||||
elif isinstance(timeout, timedelta):
|
||||
|
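The expected_statuses parameter above moves from a tuple default to Optional[Set[str]] = None, with the real default resolved inside the body. A generic sketch of that pattern, with SUCCESS standing in for GcpTransferOperationStatus.SUCCESS:

    from typing import Optional, Set

    SUCCESS = 'SUCCESS'

    def wait_for_transfer(expected_statuses: Optional[Set[str]] = None) -> Set[str]:
        # resolve the default inside the call so every invocation gets a fresh set
        expected_statuses = {SUCCESS} if not expected_statuses else expected_statuses
        return expected_statuses

    print(wait_for_transfer())            # {'SUCCESS'}
    print(wait_for_transfer({'FAILED'}))  # {'FAILED'}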
@ -417,7 +420,7 @@ class GCPTransferServiceHook(GoogleCloudBaseHook):
|
|||
@staticmethod
|
||||
def operations_contain_expected_statuses(
|
||||
operations: List[Dict],
|
||||
expected_statuses: Union[Tuple[str], str]
|
||||
expected_statuses: Union[Set[str], str]
|
||||
) -> bool:
|
||||
"""
|
||||
Checks whether the operation list has an operation with the
|
||||
|
|
|
@ -26,7 +26,7 @@ import select
|
|||
import subprocess
|
||||
import time
|
||||
import uuid
|
||||
from typing import Dict, List, Callable, Any, Optional
|
||||
from typing import Dict, List, Callable, Any, Optional, Union
|
||||
|
||||
from googleapiclient.discovery import build
|
||||
|
||||
|
@ -61,7 +61,7 @@ class _DataflowJob(LoggingMixin):
|
|||
name: str,
|
||||
location: str,
|
||||
poll_sleep: int = 10,
|
||||
job_id: str = None,
|
||||
job_id: Optional[str] = None,
|
||||
num_retries: int = 0,
|
||||
multiple_jobs: bool = False
|
||||
) -> None:
|
||||
|
@ -75,7 +75,7 @@ class _DataflowJob(LoggingMixin):
|
|||
self._poll_sleep = poll_sleep
|
||||
self._jobs = self._get_jobs()
|
||||
|
||||
def is_job_running(self):
|
||||
def is_job_running(self) -> bool:
|
||||
"""
|
||||
Helper method to check if the job is still running in Dataflow
|
||||
|
||||
|
@ -88,7 +88,7 @@ class _DataflowJob(LoggingMixin):
|
|||
return False
|
||||
|
||||
# pylint: disable=too-many-nested-blocks
|
||||
def _get_dataflow_jobs(self):
|
||||
def _get_dataflow_jobs(self) -> List:
|
||||
"""
|
||||
Helper method to get list of jobs that start with job name or id
|
||||
|
||||
|
@ -116,7 +116,7 @@ class _DataflowJob(LoggingMixin):
|
|||
else:
|
||||
raise Exception('Missing both dataflow job ID and name.')
|
||||
|
||||
def _get_jobs(self):
|
||||
def _get_jobs(self) -> List:
|
||||
"""
|
||||
Helper method to get all jobs by name
|
||||
|
||||
|
@ -145,7 +145,7 @@ class _DataflowJob(LoggingMixin):
|
|||
return self._jobs
|
||||
|
||||
# pylint: disable=too-many-nested-blocks
|
||||
def check_dataflow_job_state(self, job):
|
||||
def check_dataflow_job_state(self, job) -> bool:
|
||||
"""
|
||||
Helper method to check the state of all jobs in dataflow for this task
|
||||
if the job failed, raise an exception
|
||||
|
@ -209,7 +209,7 @@ class _DataflowJob(LoggingMixin):
|
|||
|
||||
|
||||
class _Dataflow(LoggingMixin):
|
||||
def __init__(self, cmd) -> None:
|
||||
def __init__(self, cmd: Union[List, str]) -> None:
|
||||
self.log.info("Running command: %s", ' '.join(cmd))
|
||||
self._proc = subprocess.Popen(
|
||||
cmd,
|
||||
|
@ -297,7 +297,7 @@ class DataFlowHook(GoogleCloudBaseHook):
|
|||
def __init__(
|
||||
self,
|
||||
gcp_conn_id: str = 'google_cloud_default',
|
||||
delegate_to: str = None,
|
||||
delegate_to: Optional[str] = None,
|
||||
poll_sleep: int = 10
|
||||
) -> None:
|
||||
self.poll_sleep = poll_sleep
|
||||
|
@ -328,7 +328,7 @@ class DataFlowHook(GoogleCloudBaseHook):
|
|||
.wait_for_done()
|
||||
|
||||
@staticmethod
|
||||
def _set_variables(variables: Dict):
|
||||
def _set_variables(variables: Dict) -> Dict:
|
||||
if variables['project'] is None:
|
||||
raise Exception('Project not specified')
|
||||
if 'region' not in variables.keys():
|
||||
|
@ -340,7 +340,7 @@ class DataFlowHook(GoogleCloudBaseHook):
|
|||
job_name: str,
|
||||
variables: Dict,
|
||||
jar: str,
|
||||
job_class: str = None,
|
||||
job_class: Optional[str] = None,
|
||||
append_job_name: bool = True,
|
||||
multiple_jobs: bool = False
|
||||
) -> None:
|
||||
|
@ -377,7 +377,7 @@ class DataFlowHook(GoogleCloudBaseHook):
|
|||
variables: Dict,
|
||||
parameters: Dict,
|
||||
dataflow_template: str,
|
||||
append_job_name=True
|
||||
append_job_name: bool = True
|
||||
) -> None:
|
||||
"""
|
||||
Starts Dataflow template job.
|
||||
|
|
|
@ -22,6 +22,7 @@ This module contains Google Datastore hook.
|
|||
"""
|
||||
|
||||
import time
|
||||
from typing import Any, List, Dict, Union, Optional
|
||||
|
||||
from googleapiclient.discovery import build
|
||||
|
||||
|
@ -40,14 +41,14 @@ class DatastoreHook(GoogleCloudBaseHook):
|
|||
"""
|
||||
|
||||
def __init__(self,
|
||||
datastore_conn_id='google_cloud_default',
|
||||
delegate_to=None,
|
||||
api_version='v1'):
|
||||
datastore_conn_id: str = 'google_cloud_default',
|
||||
delegate_to: Optional[str] = None,
|
||||
api_version: str = 'v1') -> None:
|
||||
super().__init__(datastore_conn_id, delegate_to)
|
||||
self.connection = None
|
||||
self.api_version = api_version
|
||||
|
||||
def get_conn(self):
|
||||
def get_conn(self) -> Any:
|
||||
"""
|
||||
Establishes a connection to the Google API.
|
||||
|
||||
|
@ -62,7 +63,7 @@ class DatastoreHook(GoogleCloudBaseHook):
|
|||
return self.connection
|
||||
|
||||
@GoogleCloudBaseHook.fallback_to_default_project_id
|
||||
def allocate_ids(self, partial_keys, project_id=None):
|
||||
def allocate_ids(self, partial_keys: List, project_id: Optional[str] = None) -> List:
|
||||
"""
|
||||
Allocate IDs for incomplete keys.
|
||||
|
||||
|
@ -76,7 +77,7 @@ class DatastoreHook(GoogleCloudBaseHook):
|
|||
:return: a list of full keys.
|
||||
:rtype: list
|
||||
"""
|
||||
conn = self.get_conn()
|
||||
conn = self.get_conn() # type: Any
|
||||
|
||||
resp = (conn # pylint:disable=no-member
|
||||
.projects()
|
||||
|
@ -86,7 +87,7 @@ class DatastoreHook(GoogleCloudBaseHook):
|
|||
return resp['keys']
|
||||
|
||||
@GoogleCloudBaseHook.fallback_to_default_project_id
|
||||
def begin_transaction(self, project_id=None):
|
||||
def begin_transaction(self, project_id: Optional[str] = None) -> str:
|
||||
"""
|
||||
Begins a new transaction.
|
||||
|
||||
|
@ -98,7 +99,7 @@ class DatastoreHook(GoogleCloudBaseHook):
|
|||
:return: a transaction handle.
|
||||
:rtype: str
|
||||
"""
|
||||
conn = self.get_conn()
|
||||
conn = self.get_conn() # type: Any
|
||||
|
||||
resp = (conn # pylint:disable=no-member
|
||||
.projects()
|
||||
|
@ -108,7 +109,7 @@ class DatastoreHook(GoogleCloudBaseHook):
|
|||
return resp['transaction']
|
||||
|
||||
@GoogleCloudBaseHook.fallback_to_default_project_id
|
||||
def commit(self, body, project_id=None):
|
||||
def commit(self, body: Dict, project_id: Optional[str] = None) -> Dict:
|
||||
"""
|
||||
Commit a transaction, optionally creating, deleting or modifying some entities.
|
||||
|
||||
|
@ -122,7 +123,7 @@ class DatastoreHook(GoogleCloudBaseHook):
|
|||
:return: the response body of the commit request.
|
||||
:rtype: dict
|
||||
"""
|
||||
conn = self.get_conn()
|
||||
conn = self.get_conn() # type: Any
|
||||
|
||||
resp = (conn # pylint:disable=no-member
|
||||
.projects()
|
||||
|
@ -132,7 +133,11 @@ class DatastoreHook(GoogleCloudBaseHook):
|
|||
return resp
|
||||
|
||||
@GoogleCloudBaseHook.fallback_to_default_project_id
|
||||
def lookup(self, keys, read_consistency=None, transaction=None, project_id=None):
|
||||
def lookup(self,
|
||||
keys: List,
|
||||
read_consistency: Optional[str] = None,
|
||||
transaction: Optional[str] = None,
|
||||
project_id: Optional[str] = None) -> Dict:
|
||||
"""
|
||||
Lookup some entities by key.
|
||||
|
||||
|
@ -151,9 +156,9 @@ class DatastoreHook(GoogleCloudBaseHook):
|
|||
:return: the response body of the lookup request.
|
||||
:rtype: dict
|
||||
"""
|
||||
conn = self.get_conn()
|
||||
conn = self.get_conn() # type: Any
|
||||
|
||||
body = {'keys': keys}
|
||||
body = {'keys': keys} # type: Dict[str, Any]
|
||||
if read_consistency:
|
||||
body['readConsistency'] = read_consistency
|
||||
if transaction:
|
||||
|
@ -166,7 +171,7 @@ class DatastoreHook(GoogleCloudBaseHook):
|
|||
return resp
|
||||
|
||||
@GoogleCloudBaseHook.fallback_to_default_project_id
|
||||
def rollback(self, transaction, project_id=None):
|
||||
def rollback(self, transaction: str, project_id: Optional[str] = None) -> Any:
|
||||
"""
|
||||
Roll back a transaction.
|
||||
|
||||
|
@ -178,14 +183,14 @@ class DatastoreHook(GoogleCloudBaseHook):
|
|||
:param project_id: Google Cloud Platform project ID against which to make the request.
|
||||
:type project_id: str
|
||||
"""
|
||||
conn = self.get_conn()
|
||||
conn = self.get_conn() # type: Any
|
||||
|
||||
conn.projects().rollback( # pylint:disable=no-member
|
||||
projectId=project_id, body={'transaction': transaction}
|
||||
).execute(num_retries=self.num_retries)
|
||||
|
||||
@GoogleCloudBaseHook.fallback_to_default_project_id
|
||||
def run_query(self, body, project_id=None):
|
||||
def run_query(self, body: Dict, project_id: Optional[str] = None) -> Dict:
|
||||
"""
|
||||
Run a query for entities.
|
||||
|
||||
|
@ -199,7 +204,7 @@ class DatastoreHook(GoogleCloudBaseHook):
|
|||
:return: the batch of query results.
|
||||
:rtype: dict
|
||||
"""
|
||||
conn = self.get_conn()
|
||||
conn = self.get_conn() # type: Any
|
||||
|
||||
resp = (conn # pylint:disable=no-member
|
||||
.projects()
|
||||
|
@ -208,7 +213,7 @@ class DatastoreHook(GoogleCloudBaseHook):
|
|||
|
||||
return resp['batch']
|
||||
|
||||
def get_operation(self, name):
|
||||
def get_operation(self, name: str) -> Dict:
|
||||
"""
|
||||
Gets the latest state of a long-running operation.
|
||||
|
||||
|
@ -220,7 +225,7 @@ class DatastoreHook(GoogleCloudBaseHook):
|
|||
:return: a resource operation instance.
|
||||
:rtype: dict
|
||||
"""
|
||||
conn = self.get_conn()
|
||||
conn = self.get_conn() # type: Any
|
||||
|
||||
resp = (conn # pylint:disable=no-member
|
||||
.projects()
|
||||
|
@ -230,7 +235,7 @@ class DatastoreHook(GoogleCloudBaseHook):
|
|||
|
||||
return resp
|
||||
|
||||
def delete_operation(self, name):
|
||||
def delete_operation(self, name: str) -> Dict:
|
||||
"""
|
||||
Deletes the long-running operation.
|
||||
|
||||
|
@ -242,7 +247,7 @@ class DatastoreHook(GoogleCloudBaseHook):
|
|||
:return: none if successful.
|
||||
:rtype: dict
|
||||
"""
|
||||
conn = self.get_conn()
|
||||
conn = self.get_conn() # type: Any
|
||||
|
||||
resp = (conn # pylint:disable=no-member
|
||||
.projects()
|
||||
|
@ -252,7 +257,7 @@ class DatastoreHook(GoogleCloudBaseHook):
|
|||
|
||||
return resp
|
||||
|
||||
def poll_operation_until_done(self, name, polling_interval_in_seconds):
|
||||
def poll_operation_until_done(self, name: str, polling_interval_in_seconds: int) -> Dict:
|
||||
"""
|
||||
Poll backup operation state until it's completed.
|
||||
|
||||
|
@ -264,9 +269,9 @@ class DatastoreHook(GoogleCloudBaseHook):
|
|||
:rtype: dict
|
||||
"""
|
||||
while True:
|
||||
result = self.get_operation(name)
|
||||
result = self.get_operation(name) # type: Dict
|
||||
|
||||
state = result['metadata']['common']['state']
|
||||
state = result['metadata']['common']['state'] # type: str
|
||||
if state == 'PROCESSING':
|
||||
self.log.info('Operation is processing. Re-polling state in {} seconds'
|
||||
.format(polling_interval_in_seconds))
|
||||
|
@ -275,8 +280,12 @@ class DatastoreHook(GoogleCloudBaseHook):
|
|||
return result
|
||||
|
||||
@GoogleCloudBaseHook.fallback_to_default_project_id
|
||||
def export_to_storage_bucket(self, bucket, namespace=None, entity_filter=None,
|
||||
labels=None, project_id=None):
|
||||
def export_to_storage_bucket(self,
|
||||
bucket: str,
|
||||
namespace: Optional[str] = None,
|
||||
entity_filter: Optional[Dict] = None,
|
||||
labels: Optional[Dict[str, str]] = None,
|
||||
project_id: Optional[str] = None) -> Dict:
|
||||
"""
|
||||
Export entities from Cloud Datastore to Cloud Storage for backup.
|
||||
|
||||
|
@ -299,9 +308,9 @@ class DatastoreHook(GoogleCloudBaseHook):
|
|||
:return: a resource operation instance.
|
||||
:rtype: dict
|
||||
"""
|
||||
admin_conn = self.get_conn()
|
||||
admin_conn = self.get_conn() # type: Any
|
||||
|
||||
output_uri_prefix = 'gs://' + '/'.join(filter(None, [bucket, namespace]))
|
||||
output_uri_prefix = 'gs://' + '/'.join(filter(None, [bucket, namespace])) # type: str
|
||||
if not entity_filter:
|
||||
entity_filter = {}
|
||||
if not labels:
|
||||
|
@ -310,7 +319,7 @@ class DatastoreHook(GoogleCloudBaseHook):
|
|||
'outputUrlPrefix': output_uri_prefix,
|
||||
'entityFilter': entity_filter,
|
||||
'labels': labels,
|
||||
}
|
||||
} # type: Dict
|
||||
resp = (admin_conn # pylint:disable=no-member
|
||||
.projects()
|
||||
.export(projectId=project_id, body=body)
|
||||
|
@ -319,8 +328,13 @@ class DatastoreHook(GoogleCloudBaseHook):
|
|||
return resp
|
||||
|
||||
@GoogleCloudBaseHook.fallback_to_default_project_id
|
||||
def import_from_storage_bucket(self, bucket, file, namespace=None,
|
||||
entity_filter=None, labels=None, project_id=None):
|
||||
def import_from_storage_bucket(self,
|
||||
bucket: str,
|
||||
file: str,
|
||||
namespace: Optional[str] = None,
|
||||
entity_filter: Optional[Dict] = None,
|
||||
labels: Optional[Union[Dict, str]] = None,
|
||||
project_id: Optional[str] = None) -> Dict:
|
||||
"""
|
||||
Import a backup from Cloud Storage to Cloud Datastore.
|
||||
|
||||
|
@ -345,9 +359,9 @@ class DatastoreHook(GoogleCloudBaseHook):
|
|||
:return: a resource operation instance.
|
||||
:rtype: dict
|
||||
"""
|
||||
admin_conn = self.get_conn()
|
||||
admin_conn = self.get_conn() # type: Any
|
||||
|
||||
input_url = 'gs://' + '/'.join(filter(None, [bucket, namespace, file]))
|
||||
input_url = 'gs://' + '/'.join(filter(None, [bucket, namespace, file])) # type: str
|
||||
if not entity_filter:
|
||||
entity_filter = {}
|
||||
if not labels:
|
||||
|
@ -356,7 +370,7 @@ class DatastoreHook(GoogleCloudBaseHook):
|
|||
'inputUrl': input_url,
|
||||
'entityFilter': entity_filter,
|
||||
'labels': labels,
|
||||
}
|
||||
} # type: Dict
|
||||
resp = (admin_conn # pylint:disable=no-member
|
||||
.projects()
|
||||
.import_(projectId=project_id, body=body)
|
||||
|
|
|
@ -19,7 +19,7 @@
|
|||
"""
|
||||
This module contains a Google Pub/Sub Hook.
|
||||
"""
|
||||
|
||||
from typing import Any, List, Dict, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from googleapiclient.discovery import build
|
||||
|
@ -53,7 +53,7 @@ class PubSubHook(GoogleCloudBaseHook):
|
|||
def __init__(self, gcp_conn_id: str = 'google_cloud_default', delegate_to: str = None) -> None:
|
||||
super().__init__(gcp_conn_id, delegate_to=delegate_to)
|
||||
|
||||
def get_conn(self):
|
||||
def get_conn(self) -> Any:
|
||||
"""
|
||||
Returns a Pub/Sub service object.
|
||||
|
||||
|
@ -63,7 +63,7 @@ class PubSubHook(GoogleCloudBaseHook):
|
|||
return build(
|
||||
'pubsub', 'v1', http=http_authorized, cache_discovery=False)
|
||||
|
||||
def publish(self, project, topic, messages):
|
||||
def publish(self, project: str, topic: str, messages: List[Dict]) -> None:
|
||||
"""
|
||||
Publishes messages to a Pub/Sub topic.
|
||||
|
||||
|
@ -87,7 +87,7 @@ class PubSubHook(GoogleCloudBaseHook):
|
|||
raise PubSubException(
|
||||
'Error publishing to topic {}'.format(full_topic), e)
|
||||
|
||||
def create_topic(self, project, topic, fail_if_exists=False):
|
||||
def create_topic(self, project: str, topic: str, fail_if_exists: bool = False) -> None:
|
||||
"""
|
||||
Creates a Pub/Sub topic, if it does not already exist.
|
||||
|
||||
|
@ -117,7 +117,7 @@ class PubSubHook(GoogleCloudBaseHook):
|
|||
raise PubSubException(
|
||||
'Error creating topic {}'.format(full_topic), e)
|
||||
|
||||
def delete_topic(self, project, topic, fail_if_not_exists=False):
|
||||
def delete_topic(self, project: str, topic: str, fail_if_not_exists: bool = False) -> None:
|
||||
"""
|
||||
Deletes a Pub/Sub topic if it exists.
|
||||
|
||||
|
@ -146,9 +146,15 @@ class PubSubHook(GoogleCloudBaseHook):
|
|||
raise PubSubException(
|
||||
'Error deleting topic {}'.format(full_topic), e)
|
||||
|
||||
def create_subscription(self, topic_project, topic, subscription=None,
|
||||
subscription_project=None, ack_deadline_secs=10,
|
||||
fail_if_exists=False):
|
||||
def create_subscription(
|
||||
self,
|
||||
topic_project: str,
|
||||
topic: str,
|
||||
subscription: Optional[str] = None,
|
||||
subscription_project: Optional[str] = None,
|
||||
ack_deadline_secs: int = 10,
|
||||
fail_if_exists: bool = False,
|
||||
) -> str:
|
||||
"""
|
||||
Creates a Pub/Sub subscription, if it does not already exist.
|
||||
|
||||
|
@ -204,8 +210,7 @@ class PubSubHook(GoogleCloudBaseHook):
|
|||
e)
|
||||
return subscription
|
||||
|
||||
def delete_subscription(self, project, subscription,
|
||||
fail_if_not_exists=False):
|
||||
def delete_subscription(self, project: str, subscription: str, fail_if_not_exists: bool = False) -> None:
|
||||
"""
|
||||
Deletes a Pub/Sub subscription, if it exists.
|
||||
|
||||
|
@ -236,8 +241,9 @@ class PubSubHook(GoogleCloudBaseHook):
|
|||
'Error deleting subscription {}'.format(full_subscription),
|
||||
e)
|
||||
|
||||
def pull(self, project, subscription, max_messages,
|
||||
return_immediately=False):
|
||||
def pull(
|
||||
self, project: str, subscription: str, max_messages: int, return_immediately: bool = False
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
Pulls up to ``max_messages`` messages from Pub/Sub subscription.
|
||||
|
||||
|
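A hedged usage sketch pairing the annotated pull() and acknowledge(); the import path, project and subscription names are placeholders, and the ackId field reflects the Pub/Sub REST message shape:

    from airflow.contrib.hooks.gcp_pubsub_hook import PubSubHook  # path may differ by version

    hook = PubSubHook(gcp_conn_id='google_cloud_default')
    messages = hook.pull(project='my-project', subscription='my-sub',
                         max_messages=5, return_immediately=True)  # -> List[Dict]
    if messages:
        hook.acknowledge(project='my-project', subscription='my-sub',
                         ack_ids=[m['ackId'] for m in messages])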
@ -273,7 +279,7 @@ class PubSubHook(GoogleCloudBaseHook):
|
|||
'Error pulling messages from subscription {}'.format(
|
||||
full_subscription), e)
|
||||
|
||||
def acknowledge(self, project, subscription, ack_ids):
|
||||
def acknowledge(self, project: str, subscription: str, ack_ids: List) -> None:
|
||||
"""
|
||||
Acknowledges the messages associated with the ``ack_ids`` from Pub/Sub subscription.