[AIRFLOW-5424] Type annotations for GCP hooks

Tobiasz Kędzierski 2019-08-28 17:41:09 +02:00, committed by Jarek Potiuk
Parent 8e74ad7e64
Commit f4a6586429
5 changed files with 314 additions and 260 deletions

View file

@ -24,9 +24,11 @@ implementation for BigQuery.
import time
from copy import deepcopy
from typing import Any, NoReturn, Mapping, Union, Iterable, Dict, List, Optional, Tuple, Type
from googleapiclient.discovery import build
from googleapiclient.errors import HttpError
from pandas import DataFrame
from pandas_gbq.gbq import \
_check_google_client_version as gbq_check_google_client_version
from pandas_gbq import read_gbq
@ -45,19 +47,19 @@ class BigQueryHook(GoogleCloudBaseHook, DbApiHook):
Interact with BigQuery. This hook uses the Google Cloud Platform
connection.
"""
conn_name_attr = 'bigquery_conn_id'
conn_name_attr = 'bigquery_conn_id' # type: str
def __init__(self,
bigquery_conn_id='google_cloud_default',
delegate_to=None,
use_legacy_sql=True,
location=None):
bigquery_conn_id: str = 'google_cloud_default',
delegate_to: Optional[str] = None,
use_legacy_sql: bool = True,
location: Optional[str] = None) -> None:
super().__init__(
gcp_conn_id=bigquery_conn_id, delegate_to=delegate_to)
self.use_legacy_sql = use_legacy_sql
self.location = location
def get_conn(self):
def get_conn(self) -> "BigQueryConnection":
"""
Returns a BigQuery PEP 249 connection object.
"""
@ -70,7 +72,7 @@ class BigQueryHook(GoogleCloudBaseHook, DbApiHook):
num_retries=self.num_retries
)
def get_service(self):
def get_service(self) -> Any:
"""
Returns a BigQuery service object.
"""
@ -78,7 +80,9 @@ class BigQueryHook(GoogleCloudBaseHook, DbApiHook):
return build(
'bigquery', 'v2', http=http_authorized, cache_discovery=False)
def insert_rows(self, table, rows, target_fields=None, commit_every=1000):
def insert_rows(
self, table: Any, rows: Any, target_fields: Any = None, commit_every: Any = 1000, replace: Any = False
) -> NoReturn:
"""
Insertion is currently unsupported. Theoretically, you could use
BigQuery's streaming API to insert rows into a table, but this hasn't
@ -86,7 +90,9 @@ class BigQueryHook(GoogleCloudBaseHook, DbApiHook):
"""
raise NotImplementedError()
def get_pandas_df(self, sql, parameters=None, dialect=None):
def get_pandas_df(
self, sql: str, parameters: Optional[Union[Iterable, Mapping]] = None, dialect: Optional[str] = None
) -> DataFrame:
"""
Returns a Pandas DataFrame for the results produced by a BigQuery
query. The DbApiHook method must be overridden because Pandas
@ -115,7 +121,7 @@ class BigQueryHook(GoogleCloudBaseHook, DbApiHook):
verbose=False,
credentials=credentials)
def table_exists(self, project_id, dataset_id, table_id):
def table_exists(self, project_id: str, dataset_id: str, table_id: str) -> bool:
"""
Checks for the existence of a table in Google BigQuery.
@ -150,12 +156,9 @@ class BigQueryPandasConnector(GbqConnector):
service account credentials into the binding.
"""
def __init__(self,
project_id,
service,
reauth=False,
verbose=False,
dialect='legacy'):
def __init__(
self, project_id: str, service: str, reauth: bool = False, verbose: bool = False, dialect="legacy"
) -> None:
super().__init__(project_id)
gbq_check_google_client_version()
gbq_test_google_api_imports()
@ -173,21 +176,21 @@ class BigQueryConnection:
work.
"""
def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
self._args = args
self._kwargs = kwargs
def close(self):
def close(self) -> None:
""" BigQueryConnection does not have anything to close. """
def commit(self):
def commit(self) -> None:
""" BigQueryConnection does not support transactions. """
def cursor(self):
def cursor(self) -> "BigQueryCursor":
""" Return a new :py:class:`Cursor` object using the connection. """
return BigQueryCursor(*self._args, **self._kwargs)
def rollback(self):
def rollback(self) -> NoReturn:
""" BigQueryConnection does not have transactions """
raise NotImplementedError(
"BigQueryConnection does not have transactions")
@ -201,12 +204,12 @@ class BigQueryBaseCursor(LoggingMixin):
"""
def __init__(self,
service,
project_id,
use_legacy_sql=True,
api_resource_configs=None,
location=None,
num_retries=5):
service: Any,
project_id: str,
use_legacy_sql: bool = True,
api_resource_configs: Optional[Dict] = None,
location: Optional[str] = None,
num_retries: int = 5) -> None:
self.service = service
self.project_id = project_id
@ -214,23 +217,23 @@ class BigQueryBaseCursor(LoggingMixin):
if api_resource_configs:
_validate_value("api_resource_configs", api_resource_configs, dict)
self.api_resource_configs = api_resource_configs \
if api_resource_configs else {}
self.running_job_id = None
if api_resource_configs else {} # type: Dict
self.running_job_id = None # type: Optional[str]
self.location = location
self.num_retries = num_retries
# pylint: disable=too-many-arguments
def create_empty_table(self,
project_id,
dataset_id,
table_id,
schema_fields=None,
time_partitioning=None,
cluster_fields=None,
labels=None,
view=None,
encryption_configuration=None,
num_retries=5):
project_id: str,
dataset_id: str,
table_id: str,
schema_fields: Optional[List] = None,
time_partitioning: Optional[Dict] = None,
cluster_fields: Optional[List] = None,
labels: Optional[Dict] = None,
view: Optional[Dict] = None,
encryption_configuration: Optional[Dict] = None,
num_retries: int = 5) -> None:
"""
Creates a new, empty table in the dataset.
To create a view, which is defined by a SQL query, pass a dictionary to the 'view' kwarg
@ -291,7 +294,7 @@ class BigQueryBaseCursor(LoggingMixin):
'tableReference': {
'tableId': table_id
}
}
} # type: Dict[str, Any]
if schema_fields:
table_resource['schema'] = {'fields': schema_fields}
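Editor's note (not part of the diff): the # type: Dict[str, Any] comments used throughout this change are PEP 484 type comments, presumably chosen over PEP 526 variable annotations so the modules stay importable on Python 3.5. To a type checker the two spellings are equivalent:

from typing import Any, Dict

# PEP 484 type comment (accepted on Python 3.5+):
table_resource = {}  # type: Dict[str, Any]

# PEP 526 variable annotation (Python 3.6+ only):
table_resource_annotated: Dict[str, Any] = {}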
@ -333,23 +336,23 @@ class BigQueryBaseCursor(LoggingMixin):
)
def create_external_table(self, # pylint: disable=too-many-locals,too-many-arguments
external_project_dataset_table,
schema_fields,
source_uris,
source_format='CSV',
autodetect=False,
compression='NONE',
ignore_unknown_values=False,
max_bad_records=0,
skip_leading_rows=0,
field_delimiter=',',
quote_character=None,
allow_quoted_newlines=False,
allow_jagged_rows=False,
src_fmt_configs=None,
labels=None,
encryption_configuration=None
):
external_project_dataset_table: str,
schema_fields: List,
source_uris: List,
source_format: str = 'CSV',
autodetect: bool = False,
compression: str = 'NONE',
ignore_unknown_values: bool = False,
max_bad_records: int = 0,
skip_leading_rows: int = 0,
field_delimiter: str = ',',
quote_character: Optional[str] = None,
allow_quoted_newlines: bool = False,
allow_jagged_rows: bool = False,
src_fmt_configs: Optional[Dict] = None,
labels: Optional[Dict] = None,
encryption_configuration: Optional[Dict] = None
) -> None:
"""
Creates a new external table in the dataset with the data in Google
Cloud Storage. See here:
@ -437,14 +440,14 @@ class BigQueryBaseCursor(LoggingMixin):
allowed_formats = [
"CSV", "NEWLINE_DELIMITED_JSON", "AVRO", "GOOGLE_SHEETS",
"DATASTORE_BACKUP", "PARQUET"
]
] # type: List[str]
if source_format not in allowed_formats:
raise ValueError("{0} is not a valid source format. "
"Please use one of the following types: {1}"
.format(source_format, allowed_formats))
compression = compression.upper()
allowed_compressions = ['NONE', 'GZIP']
allowed_compressions = ['NONE', 'GZIP'] # type: List[str]
if compression not in allowed_compressions:
raise ValueError("{0} is not a valid compression format. "
"Please use one of the following types: {1}"
@ -463,7 +466,7 @@ class BigQueryBaseCursor(LoggingMixin):
'datasetId': dataset_id,
'tableId': external_table_id,
}
}
} # type: Dict[str, Any]
if schema_fields:
table_resource['externalDataConfiguration'].update({
@ -532,19 +535,19 @@ class BigQueryBaseCursor(LoggingMixin):
)
def patch_table(self, # pylint: disable=too-many-arguments
dataset_id,
table_id,
project_id=None,
description=None,
expiration_time=None,
external_data_configuration=None,
friendly_name=None,
labels=None,
schema=None,
time_partitioning=None,
view=None,
require_partition_filter=None,
encryption_configuration=None):
dataset_id: str,
table_id: str,
project_id: Optional[str] = None,
description: Optional[str] = None,
expiration_time: Optional[int] = None,
external_data_configuration: Optional[Dict] = None,
friendly_name: Optional[str] = None,
labels: Optional[Dict] = None,
schema: Optional[List] = None,
time_partitioning: Optional[Dict] = None,
view: Optional[Dict] = None,
require_partition_filter: Optional[bool] = None,
encryption_configuration: Optional[Dict] = None) -> None:
"""
Patch information in an existing table.
It only updates fields that are provided in the request object.
@ -608,7 +611,7 @@ class BigQueryBaseCursor(LoggingMixin):
project_id = project_id if project_id is not None else self.project_id
table_resource = {}
table_resource = {} # type: Dict[str, Any]
if description is not None:
table_resource['description'] = description
@ -651,25 +654,25 @@ class BigQueryBaseCursor(LoggingMixin):
# pylint: disable=too-many-locals,too-many-arguments, too-many-branches
def run_query(self,
sql,
destination_dataset_table=None,
write_disposition='WRITE_EMPTY',
allow_large_results=False,
flatten_results=None,
udf_config=None,
use_legacy_sql=None,
maximum_billing_tier=None,
maximum_bytes_billed=None,
create_disposition='CREATE_IF_NEEDED',
query_params=None,
labels=None,
schema_update_options=None,
priority='INTERACTIVE',
time_partitioning=None,
api_resource_configs=None,
cluster_fields=None,
location=None,
encryption_configuration=None):
sql: str,
destination_dataset_table: Optional[str] = None,
write_disposition: str = 'WRITE_EMPTY',
allow_large_results: bool = False,
flatten_results: Optional[bool] = None,
udf_config: Optional[List] = None,
use_legacy_sql: Optional[bool] = None,
maximum_billing_tier: Optional[int] = None,
maximum_bytes_billed: Optional[float] = None,
create_disposition: str = 'CREATE_IF_NEEDED',
query_params: Optional[List] = None,
labels: Optional[Dict] = None,
schema_update_options: Optional[Iterable] = None,
priority: str = 'INTERACTIVE',
time_partitioning: Optional[Dict] = None,
api_resource_configs: Optional[Dict] = None,
cluster_fields: Optional[List[str]] = None,
location: Optional[str] = None,
encryption_configuration: Optional[Dict] = None) -> str:
"""
Executes a BigQuery SQL query. Optionally persists results in a BigQuery
table. See here:
@ -802,14 +805,14 @@ class BigQueryBaseCursor(LoggingMixin):
_split_tablename(table_input=destination_dataset_table,
default_project_id=self.project_id)
destination_dataset_table = {
destination_dataset_table = { # type: ignore
'projectId': destination_project,
'datasetId': destination_dataset,
'tableId': destination_table,
}
if cluster_fields:
cluster_fields = {'fields': cluster_fields}
cluster_fields = {'fields': cluster_fields} # type: ignore
query_param_list = [
(sql, 'query', None, (str,)),
@ -823,7 +826,7 @@ class BigQueryBaseCursor(LoggingMixin):
(schema_update_options, 'schemaUpdateOptions', None, list),
(destination_dataset_table, 'destinationTable', None, dict),
(cluster_fields, 'clustering', None, dict),
]
] # type: List[Tuple]
for param, param_name, param_default, param_type in query_param_list:
if param_name not in configuration['query'] and param in [None, {}, ()]:
@ -888,13 +891,13 @@ class BigQueryBaseCursor(LoggingMixin):
def run_extract( # noqa
self,
source_project_dataset_table,
destination_cloud_storage_uris,
compression='NONE',
export_format='CSV',
field_delimiter=',',
print_header=True,
labels=None):
source_project_dataset_table: str,
destination_cloud_storage_uris: str,
compression: str = 'NONE',
export_format: str = 'CSV',
field_delimiter: str = ',',
print_header: bool = True,
labels: Optional[Dict] = None) -> str:
"""
Executes a BigQuery extract command to copy data from BigQuery to
Google Cloud Storage. See here:
@ -940,7 +943,7 @@ class BigQueryBaseCursor(LoggingMixin):
'destinationUris': destination_cloud_storage_uris,
'destinationFormat': export_format,
}
}
} # type: Dict[str, Any]
if labels:
configuration['labels'] = labels
@ -955,12 +958,12 @@ class BigQueryBaseCursor(LoggingMixin):
return self.run_with_configuration(configuration)
def run_copy(self, # pylint: disable=invalid-name
source_project_dataset_tables,
destination_project_dataset_table,
write_disposition='WRITE_EMPTY',
create_disposition='CREATE_IF_NEEDED',
labels=None,
encryption_configuration=None):
source_project_dataset_tables: Union[List, str],
destination_project_dataset_table: str,
write_disposition: str = 'WRITE_EMPTY',
create_disposition: str = 'CREATE_IF_NEEDED',
labels: Optional[Dict] = None,
encryption_configuration: Optional[Dict] = None) -> str:
"""
Executes a BigQuery copy command to copy data from one BigQuery table
to another. See here:
@ -1041,25 +1044,25 @@ class BigQueryBaseCursor(LoggingMixin):
return self.run_with_configuration(configuration)
def run_load(self, # pylint: disable=too-many-locals,too-many-arguments,invalid-name
destination_project_dataset_table,
source_uris,
schema_fields=None,
source_format='CSV',
create_disposition='CREATE_IF_NEEDED',
skip_leading_rows=0,
write_disposition='WRITE_EMPTY',
field_delimiter=',',
max_bad_records=0,
quote_character=None,
ignore_unknown_values=False,
allow_quoted_newlines=False,
allow_jagged_rows=False,
schema_update_options=None,
src_fmt_configs=None,
time_partitioning=None,
cluster_fields=None,
autodetect=False,
encryption_configuration=None):
destination_project_dataset_table: str,
source_uris: List,
schema_fields: Optional[List] = None,
source_format: str = 'CSV',
create_disposition: str = 'CREATE_IF_NEEDED',
skip_leading_rows: int = 0,
write_disposition: str = 'WRITE_EMPTY',
field_delimiter: str = ',',
max_bad_records: int = 0,
quote_character: Optional[str] = None,
ignore_unknown_values: bool = False,
allow_quoted_newlines: bool = False,
allow_jagged_rows: bool = False,
schema_update_options: Optional[Iterable] = None,
src_fmt_configs: Optional[Dict] = None,
time_partitioning: Optional[Dict] = None,
cluster_fields: Optional[List] = None,
autodetect: bool = False,
encryption_configuration: Optional[Dict] = None) -> str:
"""
Executes a BigQuery load command to load data from Google Cloud Storage
to BigQuery. See here:
@ -1234,6 +1237,19 @@ class BigQueryBaseCursor(LoggingMixin):
"destinationEncryptionConfiguration"
] = encryption_configuration
# if following fields are not specified in src_fmt_configs,
# honor the top-level params for backward-compatibility
if 'skipLeadingRows' not in src_fmt_configs:
src_fmt_configs['skipLeadingRows'] = skip_leading_rows
if 'fieldDelimiter' not in src_fmt_configs:
src_fmt_configs['fieldDelimiter'] = field_delimiter
if 'ignoreUnknownValues' not in src_fmt_configs:
src_fmt_configs['ignoreUnknownValues'] = ignore_unknown_values
if quote_character is not None:
src_fmt_configs['quote'] = quote_character
if allow_quoted_newlines:
src_fmt_configs['allowQuotedNewlines'] = allow_quoted_newlines
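Editor's sketch (not part of the diff): how the backward-compatibility block above behaves. Top-level keyword arguments are copied into src_fmt_configs only when the caller did not already set the corresponding camelCase key:

src_fmt_configs = {'fieldDelimiter': '|'}  # caller-provided value wins
skip_leading_rows, field_delimiter, ignore_unknown_values = 1, ',', False

if 'skipLeadingRows' not in src_fmt_configs:
    src_fmt_configs['skipLeadingRows'] = skip_leading_rows
if 'fieldDelimiter' not in src_fmt_configs:
    src_fmt_configs['fieldDelimiter'] = field_delimiter
if 'ignoreUnknownValues' not in src_fmt_configs:
    src_fmt_configs['ignoreUnknownValues'] = ignore_unknown_values

assert src_fmt_configs == {'fieldDelimiter': '|', 'skipLeadingRows': 1, 'ignoreUnknownValues': False}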
src_fmt_to_configs_mapping = {
'CSV': [
'allowJaggedRows', 'allowQuotedNewlines', 'autodetect',
@ -1266,7 +1282,7 @@ class BigQueryBaseCursor(LoggingMixin):
return self.run_with_configuration(configuration)
def run_with_configuration(self, configuration):
def run_with_configuration(self, configuration: Dict) -> str:
"""
Executes a BigQuery SQL query. See here:
@ -1279,8 +1295,8 @@ class BigQueryBaseCursor(LoggingMixin):
https://cloud.google.com/bigquery/docs/reference/v2/jobs for
details.
"""
jobs = self.service.jobs()
job_data = {'configuration': configuration}
jobs = self.service.jobs() # type: Any
job_data = {'configuration': configuration} # type: Dict[str, Dict]
# Send query and wait for reply.
query_reply = jobs \
@ -1293,7 +1309,7 @@ class BigQueryBaseCursor(LoggingMixin):
location = self.location
# Wait for query to finish.
keep_polling_job = True
keep_polling_job = True # type: bool
while keep_polling_job:
try:
keep_polling_job = self._check_query_status(jobs, keep_polling_job, location)
@ -1309,9 +1325,9 @@ class BigQueryBaseCursor(LoggingMixin):
'BigQuery job status check failed. Final error was: {}'.
format(err.resp.status))
return self.running_job_id
return self.running_job_id # type: ignore
def _check_query_status(self, jobs, keep_polling_job, location):
def _check_query_status(self, jobs: Any, keep_polling_job: bool, location: str) -> bool:
if location:
job = jobs.get(
projectId=self.project_id,
@ -1335,7 +1351,7 @@ class BigQueryBaseCursor(LoggingMixin):
time.sleep(5)
return keep_polling_job
def poll_job_complete(self, job_id):
def poll_job_complete(self, job_id: str) -> bool:
"""
Check if the job has completed.
@ -1365,7 +1381,7 @@ class BigQueryBaseCursor(LoggingMixin):
format(err.resp.status))
return False
def cancel_query(self):
def cancel_query(self) -> None:
"""
Cancel all started queries that have not yet completed
"""
@ -1408,7 +1424,7 @@ class BigQueryBaseCursor(LoggingMixin):
self.running_job_id)
time.sleep(5)
def get_schema(self, dataset_id, table_id):
def get_schema(self, dataset_id: str, table_id: str) -> Dict:
"""
Get the schema for a given dataset.table.
see https://cloud.google.com/bigquery/docs/reference/v2/tables#resource
@ -1422,9 +1438,9 @@ class BigQueryBaseCursor(LoggingMixin):
.execute(num_retries=self.num_retries)
return tables_resource['schema']
def get_tabledata(self, dataset_id, table_id,
max_results=None, selected_fields=None, page_token=None,
start_index=None):
def get_tabledata(self, dataset_id: str, table_id: str,
max_results: Optional[int] = None, selected_fields: Optional[str] = None,
page_token: Optional[str] = None, start_index: Optional[int] = None) -> Dict:
"""
Get the data of a given dataset.table and optionally with selected columns.
see https://cloud.google.com/bigquery/docs/reference/v2/tabledata/list
@ -1439,7 +1455,7 @@ class BigQueryBaseCursor(LoggingMixin):
:param start_index: zero based index of the starting row to read.
:return: map containing the requested rows.
"""
optional_params = {}
optional_params = {} # type: Dict[str, Any]
if max_results:
optional_params['maxResults'] = max_results
if selected_fields:
@ -1454,8 +1470,8 @@ class BigQueryBaseCursor(LoggingMixin):
tableId=table_id,
**optional_params).execute(num_retries=self.num_retries))
def run_table_delete(self, deletion_dataset_table,
ignore_if_missing=False):
def run_table_delete(self, deletion_dataset_table: str,
ignore_if_missing: bool = False) -> None:
"""
Delete an existing table from the dataset;
If the table does not exist, return an error unless ignore_if_missing
@ -1488,7 +1504,8 @@ class BigQueryBaseCursor(LoggingMixin):
else:
self.log.info('Table does not exist. Skipping.')
def run_table_upsert(self, dataset_id, table_resource, project_id=None):
def run_table_upsert(self, dataset_id: str, table_resource: Dict,
project_id: Optional[str] = None) -> Dict:
"""
Creates a new, empty table in the dataset;
if the table already exists, updates the existing table.
@ -1538,11 +1555,11 @@ class BigQueryBaseCursor(LoggingMixin):
body=table_resource).execute(num_retries=self.num_retries)
def run_grant_dataset_view_access(self,
source_dataset,
view_dataset,
view_table,
source_project=None,
view_project=None):
source_dataset: str,
view_dataset: str,
view_table: str,
source_project: Optional[str] = None,
view_project: Optional[str] = None) -> Dict:
"""
Grant authorized view access of a dataset to a view table.
If this view has already been granted access to the dataset, do nothing.
@ -1600,8 +1617,11 @@ class BigQueryBaseCursor(LoggingMixin):
view_project, view_dataset, view_table, source_project, source_dataset)
return source_dataset_resource
def create_empty_dataset(self, dataset_id="", project_id="",
location=None, dataset_reference=None):
def create_empty_dataset(self,
dataset_id: str = "",
project_id: str = "",
location: Optional[str] = None,
dataset_reference: Optional[Dict] = None) -> None:
"""
Create a new empty dataset:
https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets/insert
@ -1659,9 +1679,8 @@ class BigQueryBaseCursor(LoggingMixin):
'location', location,
dataset_reference, 'dataset_reference')
dataset_id = dataset_reference.get("datasetReference").get("datasetId")
dataset_project_id = dataset_reference.get("datasetReference").get(
"projectId")
dataset_id = dataset_reference.get("datasetReference").get("datasetId") # type: ignore
dataset_project_id = dataset_reference.get("datasetReference").get("projectId") # type: ignore
self.log.info('Creating Dataset: %s in project: %s ', dataset_id,
dataset_project_id)
@ -1678,7 +1697,7 @@ class BigQueryBaseCursor(LoggingMixin):
'BigQuery job failed. Error was: {}'.format(err.content)
)
def delete_dataset(self, project_id, dataset_id, delete_contents=False):
def delete_dataset(self, project_id: str, dataset_id: str, delete_contents: bool = False) -> None:
"""
Delete a dataset of Big query in your project.
@ -1710,7 +1729,7 @@ class BigQueryBaseCursor(LoggingMixin):
'BigQuery job failed. Error was: {}'.format(err.content)
)
def get_dataset(self, dataset_id, project_id=None):
def get_dataset(self, dataset_id: str, project_id: Optional[str] = None) -> Dict:
"""
Method returns dataset_resource if the dataset exists
and raises a 404 error if it does not exist
@ -1742,7 +1761,7 @@ class BigQueryBaseCursor(LoggingMixin):
return dataset_resource
def get_datasets_list(self, project_id=None):
def get_datasets_list(self, project_id: Optional[str] = None) -> List:
"""
Method returns full list of BigQuery datasets in the current project
@ -1790,7 +1809,7 @@ class BigQueryBaseCursor(LoggingMixin):
return datasets_list
def patch_dataset(self, dataset_id, dataset_resource, project_id=None):
def patch_dataset(self, dataset_id: str, dataset_resource: str, project_id: Optional[str] = None) -> Dict:
"""
Patches information in an existing dataset.
It only replaces fields that are provided in the submitted dataset resource.
@ -1835,7 +1854,8 @@ class BigQueryBaseCursor(LoggingMixin):
return dataset
def update_dataset(self, dataset_id, dataset_resource, project_id=None):
def update_dataset(self, dataset_id: str,
dataset_resource: Dict, project_id: Optional[str] = None) -> Dict:
"""
Updates information in an existing dataset. The update method replaces the entire
dataset resource, whereas the patch method only replaces fields that are provided
@ -1881,9 +1901,9 @@ class BigQueryBaseCursor(LoggingMixin):
return dataset
def insert_all(self, project_id, dataset_id, table_id,
rows, ignore_unknown_values=False,
skip_invalid_rows=False, fail_on_error=False):
def insert_all(self, project_id: str, dataset_id: str, table_id: str,
rows: List, ignore_unknown_values: bool = False,
skip_invalid_rows: bool = False, fail_on_error: bool = False) -> None:
"""
Method to stream data into BigQuery one record at a time without needing
to run a load job
@ -1967,7 +1987,14 @@ class BigQueryCursor(BigQueryBaseCursor):
https://github.com/dropbox/PyHive/blob/master/pyhive/common.py
"""
def __init__(self, service, project_id, use_legacy_sql=True, location=None, num_retries=5):
def __init__(
self,
service: Any,
project_id: str,
use_legacy_sql: bool = True,
location: Optional[str] = None,
num_retries: int = 5,
) -> None:
super().__init__(
service=service,
project_id=project_id,
@ -1975,26 +2002,26 @@ class BigQueryCursor(BigQueryBaseCursor):
location=location,
num_retries=num_retries
)
self.buffersize = None
self.page_token = None
self.job_id = None
self.buffer = []
self.all_pages_loaded = False
self.buffersize = None # type: Optional[int]
self.page_token = None # type: Optional[str]
self.job_id = None # type: Optional[str]
self.buffer = [] # type: list
self.all_pages_loaded = False # type: bool
@property
def description(self):
def description(self) -> NoReturn:
""" The schema description method is not currently implemented. """
raise NotImplementedError
def close(self):
def close(self) -> None:
""" By default, do nothing """
@property
def rowcount(self):
def rowcount(self) -> int:
""" By default, return -1 to indicate that this is not supported. """
return -1
def execute(self, operation, parameters=None):
def execute(self, operation: str, parameters: Optional[Dict] = None) -> None:
"""
Executes a BigQuery query, and returns the job ID.
@ -2007,7 +2034,7 @@ class BigQueryCursor(BigQueryBaseCursor):
parameters) if parameters else operation
self.job_id = self.run_query(sql)
def executemany(self, operation, seq_of_parameters):
def executemany(self, operation: str, seq_of_parameters: List) -> None:
"""
Execute a BigQuery query multiple times with different parameters.
@ -2020,11 +2047,11 @@ class BigQueryCursor(BigQueryBaseCursor):
for parameters in seq_of_parameters:
self.execute(operation, parameters)
def fetchone(self):
def fetchone(self) -> Union[List, None]:
""" Fetch the next row of a query result set. """
return self.next()
def next(self):
def next(self) -> Union[List, None]:
"""
Helper method for fetchone, which returns the next row from a buffer.
If the buffer is empty, attempts to paginate through the result set for
@ -2067,7 +2094,7 @@ class BigQueryCursor(BigQueryBaseCursor):
return self.buffer.pop(0)
def fetchmany(self, size=None):
def fetchmany(self, size: Optional[int] = None) -> List:
"""
Fetch the next set of rows of a query result, returning a sequence of sequences
(e.g. a list of tuples). An empty sequence is returned when no more rows are
@ -2090,7 +2117,7 @@ class BigQueryCursor(BigQueryBaseCursor):
result.append(one)
return result
def fetchall(self):
def fetchall(self) -> List[List]:
"""
Fetch all (remaining) rows of a query result, returning them as a sequence of
sequences (e.g. a list of tuples).
@ -2104,27 +2131,27 @@ class BigQueryCursor(BigQueryBaseCursor):
result.append(one)
return result
def get_arraysize(self):
def get_arraysize(self) -> int:
""" Specifies the number of rows to fetch at a time with .fetchmany() """
return self.buffersize or 1
def set_arraysize(self, arraysize):
def set_arraysize(self, arraysize: int) -> None:
""" Specifies the number of rows to fetch at a time with .fetchmany() """
self.buffersize = arraysize
arraysize = property(get_arraysize, set_arraysize)
def setinputsizes(self, sizes):
def setinputsizes(self, sizes: Any) -> None:
""" Does nothing by default """
def setoutputsize(self, size, column=None):
def setoutputsize(self, size: Any, column: Any = None) -> None:
""" Does nothing by default """
def _bind_parameters(operation, parameters):
def _bind_parameters(operation: str, parameters: Dict) -> str:
""" Helper method that binds parameters to a SQL query. """
# inspired by MySQL Python Connector (conversion.py)
string_parameters = {}
string_parameters = {} # type: Dict[str, str]
for (name, value) in parameters.items():
if value is None:
string_parameters[name] = 'NULL'
@ -2135,7 +2162,7 @@ def _bind_parameters(operation, parameters):
return operation % string_parameters
def _escape(s):
def _escape(s: str) -> str:
""" Helper method that escapes parameters to a SQL query. """
e = s
e = e.replace('\\', '\\\\')
@ -2146,7 +2173,7 @@ def _escape(s):
return e
def _bq_cast(string_field, bq_type):
def _bq_cast(string_field: str, bq_type: str) -> Union[None, int, float, bool, str]:
"""
Helper method that casts a BigQuery row to the appropriate data types.
This is useful because BigQuery returns all fields as strings.
@ -2166,7 +2193,8 @@ def _bq_cast(string_field, bq_type):
return string_field
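Editor's sketch (not part of the diff): an illustration of the Union return type on _bq_cast. The stand-in below mirrors the casts implied by the BigQuery type names, not the committed implementation, which this excerpt truncates:

from typing import Union

def cast_field(string_field: str, bq_type: str) -> Union[None, int, float, bool, str]:
    # illustrative stand-in for _bq_cast
    if bq_type == 'INTEGER':
        return int(string_field)
    if bq_type in ('FLOAT', 'TIMESTAMP'):
        return float(string_field)
    if bq_type == 'BOOLEAN':
        return string_field == 'true'
    return string_field

assert cast_field('42', 'INTEGER') == 42
assert cast_field('true', 'BOOLEAN') is True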
def _split_tablename(table_input, default_project_id, var_name=None):
def _split_tablename(table_input: str, default_project_id: str,
var_name: Optional[str] = None) -> Tuple[str, str, str]:
if '.' not in table_input:
raise ValueError(
@ -2231,8 +2259,9 @@ def _split_tablename(table_input, default_project_id, var_name=None):
return project_id, dataset_id, table_id
def _cleanse_time_partitioning(destination_dataset_table, time_partitioning_in):
# if it is a partitioned table ($ is in the table name) add partition load option
def _cleanse_time_partitioning(
destination_dataset_table: Optional[str], time_partitioning_in: Optional[Dict]
) -> Dict: # if it is a partitioned table ($ is in the table name) add partition load option
if time_partitioning_in is None:
time_partitioning_in = {}
@ -2244,7 +2273,7 @@ def _cleanse_time_partitioning(destination_dataset_table, time_partitioning_in):
return time_partitioning_out
def _validate_value(key, value, expected_type):
def _validate_value(key: Any, value: Any, expected_type: Type) -> None:
""" function to check expected type and raise
error if type is not correct """
if not isinstance(value, expected_type):
@ -2252,8 +2281,8 @@ def _validate_value(key, value, expected_type):
key, expected_type, type(value)))
def _api_resource_configs_duplication_check(key, value, config_dict,
config_dict_name='api_resource_configs'):
def _api_resource_configs_duplication_check(key: Any, value: Any, config_dict: Dict,
config_dict_name='api_resource_configs') -> None:
if key in config_dict and value != config_dict[key]:
raise ValueError("Values of {param_name} param are duplicated. "
"{dict_name} contained {param_name} param "
@ -2262,8 +2291,10 @@ def _api_resource_configs_duplication_check(key, value, config_dict,
.format(param_name=key, dict_name=config_dict_name))
def _validate_src_fmt_configs(source_format, src_fmt_configs, valid_configs,
backward_compatibility_configs=None):
def _validate_src_fmt_configs(source_format: str,
src_fmt_configs: Dict,
valid_configs: List[str],
backward_compatibility_configs: Optional[Dict] = None) -> Dict:
"""
Validates the given src_fmt_configs against a valid configuration for the source format.
Adds the backward compatibility config to the src_fmt_configs.
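Editor's note (not part of the diff): a hedged usage sketch of the PEP 249 surface annotated in this file. The import path and connection setup are assumptions, and running it requires a configured google_cloud_default connection:

from airflow.gcp.hooks.bigquery import BigQueryHook  # import path assumed

hook = BigQueryHook(use_legacy_sql=False, location='EU')
conn = hook.get_conn()      # -> BigQueryConnection
cursor = conn.cursor()      # -> BigQueryCursor
cursor.execute('SELECT 1')  # returns None; the job id is kept on cursor.job_id
rows = cursor.fetchall()    # -> List[List]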

View file

@ -25,7 +25,7 @@ import time
import warnings
from copy import deepcopy
from datetime import timedelta
from typing import Dict, List, Tuple, Union, Optional
from typing import Dict, List, Union, Set, Optional
from googleapiclient.discovery import build
@ -111,7 +111,7 @@ class GCPTransferServiceHook(GoogleCloudBaseHook):
self,
api_version: str = 'v1',
gcp_conn_id: str = 'google_cloud_default',
delegate_to: str = None
delegate_to: Optional[str] = None
) -> None:
super().__init__(gcp_conn_id, delegate_to)
self.api_version = api_version
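Editor's note (not part of the diff): the recurring change from delegate_to: str = None to Optional[str] = None makes the None default explicit, as PEP 484 recommends, and mypy can enforce it (--no-implicit-optional, the default in recent versions). A minimal sketch:

from typing import Optional

def implicit(delegate_to: str = None) -> None:  # rejected under --no-implicit-optional
    ...

def explicit(delegate_to: Optional[str] = None) -> None:
    ...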
@ -150,7 +150,7 @@ class GCPTransferServiceHook(GoogleCloudBaseHook):
@GoogleCloudBaseHook.fallback_to_default_project_id
@GoogleCloudBaseHook.catch_http_exception
def get_transfer_job(self, job_name: str, project_id: str = None) -> Dict:
def get_transfer_job(self, job_name: str, project_id: Optional[str] = None) -> Dict:
"""
Gets the latest state of a long-running operation in Google Storage
Transfer Service.
@ -172,7 +172,7 @@ class GCPTransferServiceHook(GoogleCloudBaseHook):
.execute(num_retries=self.num_retries)
)
def list_transfer_job(self, request_filter: Dict = None, **kwargs) -> List[Dict]:
def list_transfer_job(self, request_filter: Optional[Dict] = None, **kwargs) -> List[Dict]:
"""
Lists long-running operations in Google Storage Transfer
Service that match the specified filter.
@ -230,7 +230,7 @@ class GCPTransferServiceHook(GoogleCloudBaseHook):
@GoogleCloudBaseHook.fallback_to_default_project_id
@GoogleCloudBaseHook.catch_http_exception
def delete_transfer_job(self, job_name: str, project_id: str = None) -> None:
def delete_transfer_job(self, job_name: str, project_id: Optional[str] = None) -> None:
"""
Deletes a transfer job. This is a soft delete. After a transfer job is
deleted, the job and all the transfer executions are subject to garbage
@ -293,7 +293,7 @@ class GCPTransferServiceHook(GoogleCloudBaseHook):
)
@GoogleCloudBaseHook.catch_http_exception
def list_transfer_operations(self, request_filter: Dict = None, **kwargs) -> List[Dict]:
def list_transfer_operations(self, request_filter: Optional[Dict] = None, **kwargs) -> List[Dict]:
"""
Gets a transfer operation in Google Storage Transfer Service.
@ -369,7 +369,7 @@ class GCPTransferServiceHook(GoogleCloudBaseHook):
def wait_for_transfer_job(
self,
job: Dict,
expected_statuses: Tuple[str] = (GcpTransferOperationStatus.SUCCESS,),
expected_statuses: Optional[Set[str]] = None,
timeout: Optional[Union[float, timedelta]] = None
) -> None:
"""
@ -388,6 +388,9 @@ class GCPTransferServiceHook(GoogleCloudBaseHook):
:type timeout: Optional[Union[float, timedelta]]
:rtype: None
"""
expected_statuses = (
{GcpTransferOperationStatus.SUCCESS} if not expected_statuses else expected_statuses
)
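Editor's sketch (not part of the diff): the block just above replaces the old tuple default with None plus an in-function fallback, so callers can pass any set of status names and no default collection is shared between calls:

from typing import Optional, Set

def resolve_statuses(expected_statuses: Optional[Set[str]] = None) -> Set[str]:
    # mirrors the fallback added above; 'SUCCESS' stands in for GcpTransferOperationStatus.SUCCESS
    return {'SUCCESS'} if not expected_statuses else expected_statuses

assert resolve_statuses() == {'SUCCESS'}
assert resolve_statuses({'FAILED', 'ABORTED'}) == {'FAILED', 'ABORTED'}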
if timeout is None:
timeout = 60
elif isinstance(timeout, timedelta):
@ -417,7 +420,7 @@ class GCPTransferServiceHook(GoogleCloudBaseHook):
@staticmethod
def operations_contain_expected_statuses(
operations: List[Dict],
expected_statuses: Union[Tuple[str], str]
expected_statuses: Union[Set[str], str]
) -> bool:
"""
Checks whether the operation list has an operation with the

View file

@ -26,7 +26,7 @@ import select
import subprocess
import time
import uuid
from typing import Dict, List, Callable, Any, Optional
from typing import Dict, List, Callable, Any, Optional, Union
from googleapiclient.discovery import build
@ -61,7 +61,7 @@ class _DataflowJob(LoggingMixin):
name: str,
location: str,
poll_sleep: int = 10,
job_id: str = None,
job_id: Optional[str] = None,
num_retries: int = 0,
multiple_jobs: bool = False
) -> None:
@ -75,7 +75,7 @@ class _DataflowJob(LoggingMixin):
self._poll_sleep = poll_sleep
self._jobs = self._get_jobs()
def is_job_running(self):
def is_job_running(self) -> bool:
"""
Helper method to check if the job is still running in Dataflow
@ -88,7 +88,7 @@ class _DataflowJob(LoggingMixin):
return False
# pylint: disable=too-many-nested-blocks
def _get_dataflow_jobs(self):
def _get_dataflow_jobs(self) -> List:
"""
Helper method to get list of jobs that start with job name or id
@ -116,7 +116,7 @@ class _DataflowJob(LoggingMixin):
else:
raise Exception('Missing both dataflow job ID and name.')
def _get_jobs(self):
def _get_jobs(self) -> List:
"""
Helper method to get all jobs by name
@ -145,7 +145,7 @@ class _DataflowJob(LoggingMixin):
return self._jobs
# pylint: disable=too-many-nested-blocks
def check_dataflow_job_state(self, job):
def check_dataflow_job_state(self, job) -> bool:
"""
Helper method to check the state of all jobs in dataflow for this task
if a job failed, raise an exception
@ -209,7 +209,7 @@ class _DataflowJob(LoggingMixin):
class _Dataflow(LoggingMixin):
def __init__(self, cmd) -> None:
def __init__(self, cmd: Union[List, str]) -> None:
self.log.info("Running command: %s", ' '.join(cmd))
self._proc = subprocess.Popen(
cmd,
@ -297,7 +297,7 @@ class DataFlowHook(GoogleCloudBaseHook):
def __init__(
self,
gcp_conn_id: str = 'google_cloud_default',
delegate_to: str = None,
delegate_to: Optional[str] = None,
poll_sleep: int = 10
) -> None:
self.poll_sleep = poll_sleep
@ -328,7 +328,7 @@ class DataFlowHook(GoogleCloudBaseHook):
.wait_for_done()
@staticmethod
def _set_variables(variables: Dict):
def _set_variables(variables: Dict) -> Dict:
if variables['project'] is None:
raise Exception('Project not specified')
if 'region' not in variables.keys():
@ -340,7 +340,7 @@ class DataFlowHook(GoogleCloudBaseHook):
job_name: str,
variables: Dict,
jar: str,
job_class: str = None,
job_class: Optional[str] = None,
append_job_name: bool = True,
multiple_jobs: bool = False
) -> None:
@ -377,7 +377,7 @@ class DataFlowHook(GoogleCloudBaseHook):
variables: Dict,
parameters: Dict,
dataflow_template: str,
append_job_name=True
append_job_name: bool = True
) -> None:
"""
Starts Dataflow template job.
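Editor's sketch (not part of the diff): _set_variables now declares a Dict return. Based only on the lines shown above, it raises when 'project' is missing and back-fills 'region'; the default value is truncated in this excerpt, so the stand-in below is hypothetical:

from typing import Dict

def set_variables(variables: Dict) -> Dict:
    if variables.get('project') is None:
        raise Exception('Project not specified')
    variables.setdefault('region', '<default-region>')  # real default not shown in this diff
    return variables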

View file

@ -22,6 +22,7 @@ This module contains Google Datastore hook.
"""
import time
from typing import Any, List, Dict, Union, Optional
from googleapiclient.discovery import build
@ -40,14 +41,14 @@ class DatastoreHook(GoogleCloudBaseHook):
"""
def __init__(self,
datastore_conn_id='google_cloud_default',
delegate_to=None,
api_version='v1'):
datastore_conn_id: str = 'google_cloud_default',
delegate_to: Optional[str] = None,
api_version: str = 'v1') -> None:
super().__init__(datastore_conn_id, delegate_to)
self.connection = None
self.api_version = api_version
def get_conn(self):
def get_conn(self) -> Any:
"""
Establishes a connection to the Google API.
@ -62,7 +63,7 @@ class DatastoreHook(GoogleCloudBaseHook):
return self.connection
@GoogleCloudBaseHook.fallback_to_default_project_id
def allocate_ids(self, partial_keys, project_id=None):
def allocate_ids(self, partial_keys: List, project_id: Optional[str] = None) -> List:
"""
Allocate IDs for incomplete keys.
@ -76,7 +77,7 @@ class DatastoreHook(GoogleCloudBaseHook):
:return: a list of full keys.
:rtype: list
"""
conn = self.get_conn()
conn = self.get_conn() # type: Any
resp = (conn # pylint:disable=no-member
.projects()
@ -86,7 +87,7 @@ class DatastoreHook(GoogleCloudBaseHook):
return resp['keys']
@GoogleCloudBaseHook.fallback_to_default_project_id
def begin_transaction(self, project_id=None):
def begin_transaction(self, project_id: Optional[str] = None) -> str:
"""
Begins a new transaction.
@ -98,7 +99,7 @@ class DatastoreHook(GoogleCloudBaseHook):
:return: a transaction handle.
:rtype: str
"""
conn = self.get_conn()
conn = self.get_conn() # type: Any
resp = (conn # pylint:disable=no-member
.projects()
@ -108,7 +109,7 @@ class DatastoreHook(GoogleCloudBaseHook):
return resp['transaction']
@GoogleCloudBaseHook.fallback_to_default_project_id
def commit(self, body, project_id=None):
def commit(self, body: Dict, project_id: Optional[str] = None) -> Dict:
"""
Commit a transaction, optionally creating, deleting or modifying some entities.
@ -122,7 +123,7 @@ class DatastoreHook(GoogleCloudBaseHook):
:return: the response body of the commit request.
:rtype: dict
"""
conn = self.get_conn()
conn = self.get_conn() # type: Any
resp = (conn # pylint:disable=no-member
.projects()
@ -132,7 +133,11 @@ class DatastoreHook(GoogleCloudBaseHook):
return resp
@GoogleCloudBaseHook.fallback_to_default_project_id
def lookup(self, keys, read_consistency=None, transaction=None, project_id=None):
def lookup(self,
keys: List,
read_consistency: Optional[str] = None,
transaction: Optional[str] = None,
project_id: Optional[str] = None) -> Dict:
"""
Lookup some entities by key.
@ -151,9 +156,9 @@ class DatastoreHook(GoogleCloudBaseHook):
:return: the response body of the lookup request.
:rtype: dict
"""
conn = self.get_conn()
conn = self.get_conn() # type: Any
body = {'keys': keys}
body = {'keys': keys} # type: Dict[str, Any]
if read_consistency:
body['readConsistency'] = read_consistency
if transaction:
@ -166,7 +171,7 @@ class DatastoreHook(GoogleCloudBaseHook):
return resp
@GoogleCloudBaseHook.fallback_to_default_project_id
def rollback(self, transaction, project_id=None):
def rollback(self, transaction: str, project_id: Optional[str] = None) -> Any:
"""
Roll back a transaction.
@ -178,14 +183,14 @@ class DatastoreHook(GoogleCloudBaseHook):
:param project_id: Google Cloud Platform project ID against which to make the request.
:type project_id: str
"""
conn = self.get_conn()
conn = self.get_conn() # type: Any
conn.projects().rollback( # pylint:disable=no-member
projectId=project_id, body={'transaction': transaction}
).execute(num_retries=self.num_retries)
@GoogleCloudBaseHook.fallback_to_default_project_id
def run_query(self, body, project_id=None):
def run_query(self, body: Dict, project_id: Optional[str] = None) -> Dict:
"""
Run a query for entities.
@ -199,7 +204,7 @@ class DatastoreHook(GoogleCloudBaseHook):
:return: the batch of query results.
:rtype: dict
"""
conn = self.get_conn()
conn = self.get_conn() # type: Any
resp = (conn # pylint:disable=no-member
.projects()
@ -208,7 +213,7 @@ class DatastoreHook(GoogleCloudBaseHook):
return resp['batch']
def get_operation(self, name):
def get_operation(self, name: str) -> Dict:
"""
Gets the latest state of a long-running operation.
@ -220,7 +225,7 @@ class DatastoreHook(GoogleCloudBaseHook):
:return: a resource operation instance.
:rtype: dict
"""
conn = self.get_conn()
conn = self.get_conn() # type: Any
resp = (conn # pylint:disable=no-member
.projects()
@ -230,7 +235,7 @@ class DatastoreHook(GoogleCloudBaseHook):
return resp
def delete_operation(self, name):
def delete_operation(self, name: str) -> Dict:
"""
Deletes the long-running operation.
@ -242,7 +247,7 @@ class DatastoreHook(GoogleCloudBaseHook):
:return: none if successful.
:rtype: dict
"""
conn = self.get_conn()
conn = self.get_conn() # type: Any
resp = (conn # pylint:disable=no-member
.projects()
@ -252,7 +257,7 @@ class DatastoreHook(GoogleCloudBaseHook):
return resp
def poll_operation_until_done(self, name, polling_interval_in_seconds):
def poll_operation_until_done(self, name: str, polling_interval_in_seconds: int) -> Dict:
"""
Poll backup operation state until it's completed.
@ -264,9 +269,9 @@ class DatastoreHook(GoogleCloudBaseHook):
:rtype: dict
"""
while True:
result = self.get_operation(name)
result = self.get_operation(name) # type: Dict
state = result['metadata']['common']['state']
state = result['metadata']['common']['state'] # type: str
if state == 'PROCESSING':
self.log.info('Operation is processing. Re-polling state in {} seconds'
.format(polling_interval_in_seconds))
@ -275,8 +280,12 @@ class DatastoreHook(GoogleCloudBaseHook):
return result
@GoogleCloudBaseHook.fallback_to_default_project_id
def export_to_storage_bucket(self, bucket, namespace=None, entity_filter=None,
labels=None, project_id=None):
def export_to_storage_bucket(self,
bucket: str,
namespace: Optional[str] = None,
entity_filter: Optional[Dict] = None,
labels: Optional[Dict[str, str]] = None,
project_id: Optional[str] = None) -> Dict:
"""
Export entities from Cloud Datastore to Cloud Storage for backup.
@ -299,9 +308,9 @@ class DatastoreHook(GoogleCloudBaseHook):
:return: a resource operation instance.
:rtype: dict
"""
admin_conn = self.get_conn()
admin_conn = self.get_conn() # type: Any
output_uri_prefix = 'gs://' + '/'.join(filter(None, [bucket, namespace]))
output_uri_prefix = 'gs://' + '/'.join(filter(None, [bucket, namespace])) # type: str
if not entity_filter:
entity_filter = {}
if not labels:
@ -310,7 +319,7 @@ class DatastoreHook(GoogleCloudBaseHook):
'outputUrlPrefix': output_uri_prefix,
'entityFilter': entity_filter,
'labels': labels,
}
} # type: Dict
resp = (admin_conn # pylint:disable=no-member
.projects()
.export(projectId=project_id, body=body)
@ -319,8 +328,13 @@ class DatastoreHook(GoogleCloudBaseHook):
return resp
@GoogleCloudBaseHook.fallback_to_default_project_id
def import_from_storage_bucket(self, bucket, file, namespace=None,
entity_filter=None, labels=None, project_id=None):
def import_from_storage_bucket(self,
bucket: str,
file: str,
namespace: Optional[str] = None,
entity_filter: Optional[Dict] = None,
labels: Optional[Union[Dict, str]] = None,
project_id: Optional[str] = None) -> Dict:
"""
Import a backup from Cloud Storage to Cloud Datastore.
@ -345,9 +359,9 @@ class DatastoreHook(GoogleCloudBaseHook):
:return: a resource operation instance.
:rtype: dict
"""
admin_conn = self.get_conn()
admin_conn = self.get_conn() # type: Any
input_url = 'gs://' + '/'.join(filter(None, [bucket, namespace, file]))
input_url = 'gs://' + '/'.join(filter(None, [bucket, namespace, file])) # type: str
if not entity_filter:
entity_filter = {}
if not labels:
@ -356,7 +370,7 @@ class DatastoreHook(GoogleCloudBaseHook):
'inputUrl': input_url,
'entityFilter': entity_filter,
'labels': labels,
}
} # type: Dict
resp = (admin_conn # pylint:disable=no-member
.projects()
.import_(projectId=project_id, body=body)
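Editor's note (not part of the diff): export_to_storage_bucket and import_from_storage_bucket above build their Cloud Storage URLs with 'gs://' + '/'.join(filter(None, [...])), so None or empty components are simply skipped. A quick worked example with hypothetical values:

bucket, namespace, file = 'my-bucket', None, 'export-metadata-file'

output_uri_prefix = 'gs://' + '/'.join(filter(None, [bucket, namespace]))
input_url = 'gs://' + '/'.join(filter(None, [bucket, namespace, file]))

assert output_uri_prefix == 'gs://my-bucket'
assert input_url == 'gs://my-bucket/export-metadata-file'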

View file

@ -19,7 +19,7 @@
"""
This module contains a Google Pub/Sub Hook.
"""
from typing import Any, List, Dict, Optional
from uuid import uuid4
from googleapiclient.discovery import build
@ -53,7 +53,7 @@ class PubSubHook(GoogleCloudBaseHook):
def __init__(self, gcp_conn_id: str = 'google_cloud_default', delegate_to: str = None) -> None:
super().__init__(gcp_conn_id, delegate_to=delegate_to)
def get_conn(self):
def get_conn(self) -> Any:
"""
Returns a Pub/Sub service object.
@ -63,7 +63,7 @@ class PubSubHook(GoogleCloudBaseHook):
return build(
'pubsub', 'v1', http=http_authorized, cache_discovery=False)
def publish(self, project, topic, messages):
def publish(self, project: str, topic: str, messages: List[Dict]) -> None:
"""
Publishes messages to a Pub/Sub topic.
@ -87,7 +87,7 @@ class PubSubHook(GoogleCloudBaseHook):
raise PubSubException(
'Error publishing to topic {}'.format(full_topic), e)
def create_topic(self, project, topic, fail_if_exists=False):
def create_topic(self, project: str, topic: str, fail_if_exists: bool = False) -> None:
"""
Creates a Pub/Sub topic, if it does not already exist.
@ -117,7 +117,7 @@ class PubSubHook(GoogleCloudBaseHook):
raise PubSubException(
'Error creating topic {}'.format(full_topic), e)
def delete_topic(self, project, topic, fail_if_not_exists=False):
def delete_topic(self, project: str, topic: str, fail_if_not_exists: bool = False) -> None:
"""
Deletes a Pub/Sub topic if it exists.
@ -146,9 +146,15 @@ class PubSubHook(GoogleCloudBaseHook):
raise PubSubException(
'Error deleting topic {}'.format(full_topic), e)
def create_subscription(self, topic_project, topic, subscription=None,
subscription_project=None, ack_deadline_secs=10,
fail_if_exists=False):
def create_subscription(
self,
topic_project: str,
topic: str,
subscription: Optional[str] = None,
subscription_project: Optional[str] = None,
ack_deadline_secs: int = 10,
fail_if_exists: bool = False,
) -> str:
"""
Creates a Pub/Sub subscription, if it does not already exist.
@ -204,8 +210,7 @@ class PubSubHook(GoogleCloudBaseHook):
e)
return subscription
def delete_subscription(self, project, subscription,
fail_if_not_exists=False):
def delete_subscription(self, project: str, subscription: str, fail_if_not_exists: bool = False) -> None:
"""
Deletes a Pub/Sub subscription, if it exists.
@ -236,8 +241,9 @@ class PubSubHook(GoogleCloudBaseHook):
'Error deleting subscription {}'.format(full_subscription),
e)
def pull(self, project, subscription, max_messages,
return_immediately=False):
def pull(
self, project: str, subscription: str, max_messages: int, return_immediately: bool = False
) -> List[Dict]:
"""
Pulls up to ``max_messages`` messages from Pub/Sub subscription.
@ -273,7 +279,7 @@ class PubSubHook(GoogleCloudBaseHook):
'Error pulling messages from subscription {}'.format(
full_subscription), e)
def acknowledge(self, project, subscription, ack_ids):
def acknowledge(self, project: str, subscription: str, ack_ids: List) -> None:
"""
Acknowledges the messages associated with the ``ack_ids`` from a Pub/Sub subscription.
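Editor's note (not part of the diff): a hedged end-to-end sketch of the annotated Pub/Sub surface, based only on the signatures above. The import path, message payload shape, and ack-id key are assumptions, and running it requires a configured GCP connection:

from airflow.gcp.hooks.pubsub import PubSubHook  # import path assumed

hook = PubSubHook()
hook.create_topic('my-project', 'my-topic', fail_if_exists=False)
sub = hook.create_subscription('my-project', 'my-topic')        # -> str (generated name)
hook.publish('my-project', 'my-topic', [{'data': 'aGVsbG8='}])  # payload shape assumed
messages = hook.pull('my-project', sub, max_messages=10, return_immediately=True)
hook.acknowledge('my-project', sub, [m['ackId'] for m in messages])  # key name assumed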