[AIRFLOW-1267][AIRFLOW-1874] Add dialect parameter to BigQueryHook
Allows a default BigQuery dialect to be specified at the hook level, which is threaded through to the underlying cursors. This allows standard SQL dialect to be used, while maintaining compatibility with the `DbApiHook` interface. Addresses AIRFLOW-1267 and AIRFLOW-1874 Closes #2964 from ji-han/master
This commit is contained in:
Родитель
24bb2b7b6d
Коммит
1021f68031
|
@ -44,9 +44,13 @@ class BigQueryHook(GoogleCloudBaseHook, DbApiHook, LoggingMixin):
|
|||
"""
|
||||
conn_name_attr = 'bigquery_conn_id'
|
||||
|
||||
def __init__(self, bigquery_conn_id='bigquery_default', delegate_to=None):
|
||||
def __init__(self,
|
||||
bigquery_conn_id='bigquery_default',
|
||||
delegate_to=None,
|
||||
use_legacy_sql=True):
|
||||
super(BigQueryHook, self).__init__(
|
||||
conn_id=bigquery_conn_id, delegate_to=delegate_to)
|
||||
self.use_legacy_sql = use_legacy_sql
|
||||
|
||||
def get_conn(self):
|
||||
"""
|
||||
|
@ -54,7 +58,10 @@ class BigQueryHook(GoogleCloudBaseHook, DbApiHook, LoggingMixin):
|
|||
"""
|
||||
service = self.get_service()
|
||||
project = self._get_field('project')
|
||||
return BigQueryConnection(service=service, project_id=project)
|
||||
return BigQueryConnection(
|
||||
service=service,
|
||||
project_id=project,
|
||||
use_legacy_sql=self.use_legacy_sql)
|
||||
|
||||
def get_service(self):
|
||||
"""
|
||||
|
@ -71,7 +78,7 @@ class BigQueryHook(GoogleCloudBaseHook, DbApiHook, LoggingMixin):
|
|||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_pandas_df(self, bql, parameters=None, dialect='legacy'):
|
||||
def get_pandas_df(self, bql, parameters=None, dialect=None):
|
||||
"""
|
||||
Returns a Pandas DataFrame for the results produced by a BigQuery
|
||||
query. The DbApiHook method must be overridden because Pandas
|
||||
|
@ -86,10 +93,15 @@ class BigQueryHook(GoogleCloudBaseHook, DbApiHook, LoggingMixin):
|
|||
used, leave to override superclass method)
|
||||
:type parameters: mapping or iterable
|
||||
:param dialect: Dialect of BigQuery SQL – legacy SQL or standard SQL
|
||||
:type dialect: string in {'legacy', 'standard'}, default 'legacy'
|
||||
defaults to use `self.use_legacy_sql` if not specified
|
||||
:type dialect: string in {'legacy', 'standard'}
|
||||
"""
|
||||
service = self.get_service()
|
||||
project = self._get_field('project')
|
||||
|
||||
if dialect is None:
|
||||
dialect = 'legacy' if self.use_legacy_sql else 'standard'
|
||||
|
||||
connector = BigQueryPandasConnector(project, service, dialect=dialect)
|
||||
schema, pages = connector.run_query(bql)
|
||||
dataframe_list = []
|
||||
|
@ -188,9 +200,10 @@ class BigQueryBaseCursor(LoggingMixin):
|
|||
PEP 249 cursor isn't needed.
|
||||
"""
|
||||
|
||||
def __init__(self, service, project_id):
|
||||
def __init__(self, service, project_id, use_legacy_sql=True):
|
||||
self.service = service
|
||||
self.project_id = project_id
|
||||
self.use_legacy_sql = use_legacy_sql
|
||||
self.running_job_id = None
|
||||
|
||||
def run_query(self,
|
||||
|
@ -199,7 +212,6 @@ class BigQueryBaseCursor(LoggingMixin):
|
|||
write_disposition='WRITE_EMPTY',
|
||||
allow_large_results=False,
|
||||
udf_config=False,
|
||||
use_legacy_sql=True,
|
||||
maximum_billing_tier=None,
|
||||
create_disposition='CREATE_IF_NEEDED',
|
||||
query_params=None,
|
||||
|
@ -224,8 +236,6 @@ class BigQueryBaseCursor(LoggingMixin):
|
|||
:param udf_config: The User Defined Function configuration for the query.
|
||||
See https://cloud.google.com/bigquery/user-defined-functions for details.
|
||||
:type udf_config: list
|
||||
:param use_legacy_sql: Whether to use legacy SQL (true) or standard SQL (false).
|
||||
:type use_legacy_sql: boolean
|
||||
:param maximum_billing_tier: Positive integer that serves as a
|
||||
multiplier of the basic price.
|
||||
:type maximum_billing_tier: integer
|
||||
|
@ -257,7 +267,7 @@ class BigQueryBaseCursor(LoggingMixin):
|
|||
configuration = {
|
||||
'query': {
|
||||
'query': bql,
|
||||
'useLegacySql': use_legacy_sql,
|
||||
'useLegacySql': self.use_legacy_sql,
|
||||
'maximumBillingTier': maximum_billing_tier
|
||||
}
|
||||
}
|
||||
|
@ -290,7 +300,7 @@ class BigQueryBaseCursor(LoggingMixin):
|
|||
})
|
||||
|
||||
if query_params:
|
||||
if use_legacy_sql:
|
||||
if self.use_legacy_sql:
|
||||
raise ValueError("Query paramaters are not allowed when using "
|
||||
"legacy SQL")
|
||||
else:
|
||||
|
@ -942,9 +952,11 @@ class BigQueryCursor(BigQueryBaseCursor):
|
|||
https://github.com/dropbox/PyHive/blob/master/pyhive/common.py
|
||||
"""
|
||||
|
||||
def __init__(self, service, project_id):
|
||||
def __init__(self, service, project_id, use_legacy_sql=True):
|
||||
super(BigQueryCursor, self).__init__(
|
||||
service=service, project_id=project_id)
|
||||
service=service,
|
||||
project_id=project_id,
|
||||
use_legacy_sql=use_legacy_sql)
|
||||
self.buffersize = None
|
||||
self.page_token = None
|
||||
self.job_id = None
|
||||
|
|
|
@ -98,6 +98,7 @@ class BigQueryOperator(BaseOperator):
|
|||
self.log.info('Executing: %s', self.bql)
|
||||
hook = BigQueryHook(
|
||||
bigquery_conn_id=self.bigquery_conn_id,
|
||||
use_legacy_sql=self.use_legacy_sql,
|
||||
delegate_to=self.delegate_to)
|
||||
conn = hook.get_conn()
|
||||
self.bq_cursor = conn.cursor()
|
||||
|
@ -107,7 +108,6 @@ class BigQueryOperator(BaseOperator):
|
|||
write_disposition=self.write_disposition,
|
||||
allow_large_results=self.allow_large_results,
|
||||
udf_config=self.udf_config,
|
||||
use_legacy_sql=self.use_legacy_sql,
|
||||
maximum_billing_tier=self.maximum_billing_tier,
|
||||
create_disposition=self.create_disposition,
|
||||
query_params=self.query_params,
|
||||
|
|
|
@ -308,6 +308,25 @@ class TestTimePartitioningInRunJob(unittest.TestCase):
|
|||
)
|
||||
|
||||
|
||||
class TestBigQueryHookLegacySql(unittest.TestCase):
|
||||
"""Ensure `use_legacy_sql` param in `BigQueryHook` propagates properly."""
|
||||
|
||||
@mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration')
|
||||
def test_hook_uses_legacy_sql_by_default(self, run_with_config):
|
||||
with mock.patch.object(hook.BigQueryHook, 'get_service'):
|
||||
bq_hook = hook.BigQueryHook()
|
||||
bq_hook.get_first('query')
|
||||
args, kwargs = run_with_config.call_args
|
||||
self.assertIs(args[0]['query']['useLegacySql'], True)
|
||||
|
||||
@mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration')
|
||||
def test_legacy_sql_override_propagates_properly(self, run_with_config):
|
||||
with mock.patch.object(hook.BigQueryHook, 'get_service'):
|
||||
bq_hook = hook.BigQueryHook(use_legacy_sql=False)
|
||||
bq_hook.get_first('query')
|
||||
args, kwargs = run_with_config.call_args
|
||||
self.assertIs(args[0]['query']['useLegacySql'], False)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче