[AIRFLOW-1267][AIRFLOW-1874] Add dialect parameter to BigQueryHook

Allows a default BigQuery dialect to be specified
at the hook level, which is threaded through to
the underlying cursors.

This allows standard SQL dialect to be used,
while maintaining compatibility with the
`DbApiHook` interface.

Addresses AIRFLOW-1267 and AIRFLOW-1874

Closes #2964 from ji-han/master
This commit is contained in:
Winston Huang 2018-01-22 18:27:40 +01:00 коммит произвёл Fokko Driesprong
Родитель 24bb2b7b6d
Коммит 1021f68031
3 изменённых файлов: 45 добавлений и 14 удалений

Просмотреть файл

@ -44,9 +44,13 @@ class BigQueryHook(GoogleCloudBaseHook, DbApiHook, LoggingMixin):
"""
conn_name_attr = 'bigquery_conn_id'
def __init__(self,
             bigquery_conn_id='bigquery_default',
             delegate_to=None,
             use_legacy_sql=True):
    """
    Hook constructor; stores the hook-level default SQL dialect.

    :param bigquery_conn_id: Airflow connection id holding the BigQuery
        credentials (default: 'bigquery_default').
    :type bigquery_conn_id: string
    :param delegate_to: account to impersonate, forwarded unchanged to
        ``GoogleCloudBaseHook``.
    :type delegate_to: string
    :param use_legacy_sql: default BigQuery SQL dialect for this hook;
        threaded through to the connections/cursors it creates.
    :type use_legacy_sql: boolean
    """
    # NOTE(review): the scraped diff left the pre-change one-line
    # signature in place alongside the new one; only the new
    # multi-line signature is kept here.
    super(BigQueryHook, self).__init__(
        conn_id=bigquery_conn_id, delegate_to=delegate_to)
    self.use_legacy_sql = use_legacy_sql
def get_conn(self):
"""
@ -54,7 +58,10 @@ class BigQueryHook(GoogleCloudBaseHook, DbApiHook, LoggingMixin):
"""
service = self.get_service()
project = self._get_field('project')
return BigQueryConnection(service=service, project_id=project)
return BigQueryConnection(
service=service,
project_id=project,
use_legacy_sql=self.use_legacy_sql)
def get_service(self):
"""
@ -71,7 +78,7 @@ class BigQueryHook(GoogleCloudBaseHook, DbApiHook, LoggingMixin):
"""
raise NotImplementedError()
def get_pandas_df(self, bql, parameters=None, dialect='legacy'):
def get_pandas_df(self, bql, parameters=None, dialect=None):
"""
Returns a Pandas DataFrame for the results produced by a BigQuery
query. The DbApiHook method must be overridden because Pandas
@ -86,10 +93,15 @@ class BigQueryHook(GoogleCloudBaseHook, DbApiHook, LoggingMixin):
used, leave to override superclass method)
:type parameters: mapping or iterable
:param dialect: Dialect of BigQuery SQL legacy SQL or standard SQL
:type dialect: string in {'legacy', 'standard'}, default 'legacy'
defaults to use `self.use_legacy_sql` if not specified
:type dialect: string in {'legacy', 'standard'}
"""
service = self.get_service()
project = self._get_field('project')
if dialect is None:
dialect = 'legacy' if self.use_legacy_sql else 'standard'
connector = BigQueryPandasConnector(project, service, dialect=dialect)
schema, pages = connector.run_query(bql)
dataframe_list = []
@ -188,9 +200,10 @@ class BigQueryBaseCursor(LoggingMixin):
PEP 249 cursor isn't needed.
"""
def __init__(self, service, project_id, use_legacy_sql=True):
    """
    Base cursor constructor.

    :param service: authenticated BigQuery API service object used to
        submit jobs.
    :param project_id: GCP project that jobs are billed to / run in.
    :type project_id: string
    :param use_legacy_sql: dialect default consumed by ``run_query``
        (True -> legacy SQL, False -> standard SQL).
    :type use_legacy_sql: boolean
    """
    # NOTE(review): the stale pre-change signature line (diff residue)
    # is dropped; only the new three-argument form remains.
    self.service = service
    self.project_id = project_id
    self.use_legacy_sql = use_legacy_sql
    # Id of the most recently started job, kept for polling/cancellation.
    self.running_job_id = None
def run_query(self,
@ -199,7 +212,6 @@ class BigQueryBaseCursor(LoggingMixin):
write_disposition='WRITE_EMPTY',
allow_large_results=False,
udf_config=False,
use_legacy_sql=True,
maximum_billing_tier=None,
create_disposition='CREATE_IF_NEEDED',
query_params=None,
@ -224,8 +236,6 @@ class BigQueryBaseCursor(LoggingMixin):
:param udf_config: The User Defined Function configuration for the query.
See https://cloud.google.com/bigquery/user-defined-functions for details.
:type udf_config: list
:param use_legacy_sql: Whether to use legacy SQL (true) or standard SQL (false).
:type use_legacy_sql: boolean
:param maximum_billing_tier: Positive integer that serves as a
multiplier of the basic price.
:type maximum_billing_tier: integer
@ -257,7 +267,7 @@ class BigQueryBaseCursor(LoggingMixin):
configuration = {
'query': {
'query': bql,
'useLegacySql': use_legacy_sql,
'useLegacySql': self.use_legacy_sql,
'maximumBillingTier': maximum_billing_tier
}
}
@ -290,7 +300,7 @@ class BigQueryBaseCursor(LoggingMixin):
})
if query_params:
if use_legacy_sql:
if self.use_legacy_sql:
raise ValueError("Query paramaters are not allowed when using "
"legacy SQL")
else:
@ -942,9 +952,11 @@ class BigQueryCursor(BigQueryBaseCursor):
https://github.com/dropbox/PyHive/blob/master/pyhive/common.py
"""
def __init__(self, service, project_id):
def __init__(self, service, project_id, use_legacy_sql=True):
super(BigQueryCursor, self).__init__(
service=service, project_id=project_id)
service=service,
project_id=project_id,
use_legacy_sql=use_legacy_sql)
self.buffersize = None
self.page_token = None
self.job_id = None

Просмотреть файл

@ -98,6 +98,7 @@ class BigQueryOperator(BaseOperator):
self.log.info('Executing: %s', self.bql)
hook = BigQueryHook(
bigquery_conn_id=self.bigquery_conn_id,
use_legacy_sql=self.use_legacy_sql,
delegate_to=self.delegate_to)
conn = hook.get_conn()
self.bq_cursor = conn.cursor()
@ -107,7 +108,6 @@ class BigQueryOperator(BaseOperator):
write_disposition=self.write_disposition,
allow_large_results=self.allow_large_results,
udf_config=self.udf_config,
use_legacy_sql=self.use_legacy_sql,
maximum_billing_tier=self.maximum_billing_tier,
create_disposition=self.create_disposition,
query_params=self.query_params,

Просмотреть файл

@ -308,6 +308,25 @@ class TestTimePartitioningInRunJob(unittest.TestCase):
)
class TestBigQueryHookLegacySql(unittest.TestCase):
    """Ensure `use_legacy_sql` param in `BigQueryHook` propagates properly."""

    @mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration')
    def test_hook_uses_legacy_sql_by_default(self, run_with_config):
        # Patch get_service so no real GCP credentials are required.
        with mock.patch.object(hook.BigQueryHook, 'get_service'):
            hook.BigQueryHook().get_first('query')
            positional, _ = run_with_config.call_args
            # Default hook must submit jobs with legacy SQL enabled.
            self.assertIs(positional[0]['query']['useLegacySql'], True)

    @mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration')
    def test_legacy_sql_override_propagates_properly(self, run_with_config):
        with mock.patch.object(hook.BigQueryHook, 'get_service'):
            hook.BigQueryHook(use_legacy_sql=False).get_first('query')
            positional, _ = run_with_config.call_args
            # Overriding the flag must flip the dialect sent to BigQuery.
            self.assertIs(positional[0]['query']['useLegacySql'], False)
# Allow running this test module directly (outside a test runner).
if __name__ == '__main__':
    unittest.main()