Add Parquet data type to BaseSQLToGCSOperator (#13359)

Parent: 10be37513c
Commit: 406181d64a
@@ -22,6 +22,8 @@ import warnings
 from tempfile import NamedTemporaryFile
 from typing import Optional, Sequence, Union
 
+import pyarrow as pa
+import pyarrow.parquet as pq
 import unicodecsv as csv
 
 from airflow.models import BaseOperator
@@ -185,6 +187,8 @@ class BaseSQLToGCSOperator(BaseOperator):
         tmp_file_handle = NamedTemporaryFile(delete=True)
         if self.export_format == 'csv':
             file_mime_type = 'text/csv'
+        elif self.export_format == 'parquet':
+            file_mime_type = 'application/octet-stream'
         else:
             file_mime_type = 'application/json'
         files_to_upload = [
@@ -198,6 +202,9 @@ class BaseSQLToGCSOperator(BaseOperator):
 
         if self.export_format == 'csv':
             csv_writer = self._configure_csv_file(tmp_file_handle, schema)
+        if self.export_format == 'parquet':
+            parquet_schema = self._convert_parquet_schema(cursor)
+            parquet_writer = self._configure_parquet_file(tmp_file_handle, parquet_schema)
 
         for row in cursor:
             # Convert datetime objects to utc seconds, and decimals to floats.
@@ -208,6 +215,12 @@ class BaseSQLToGCSOperator(BaseOperator):
                 if self.null_marker is not None:
                     row = [value if value is not None else self.null_marker for value in row]
                 csv_writer.writerow(row)
+            elif self.export_format == 'parquet':
+                if self.null_marker is not None:
+                    row = [value if value is not None else self.null_marker for value in row]
+                row_pydic = {col: [value] for col, value in zip(schema, row)}
+                tbl = pa.Table.from_pydict(row_pydic, parquet_schema)
+                parquet_writer.write_table(tbl)
             else:
                 row_dict = dict(zip(schema, row))
 
@@ -232,7 +245,8 @@ class BaseSQLToGCSOperator(BaseOperator):
                 self.log.info("Current file count: %d", len(files_to_upload))
                 if self.export_format == 'csv':
                     csv_writer = self._configure_csv_file(tmp_file_handle, schema)
+                if self.export_format == 'parquet':
+                    parquet_writer = self._configure_parquet_file(tmp_file_handle, parquet_schema)
         return files_to_upload
 
     def _configure_csv_file(self, file_handle, schema):
@@ -243,6 +257,30 @@ class BaseSQLToGCSOperator(BaseOperator):
         csv_writer.writerow(schema)
         return csv_writer
 
+    def _configure_parquet_file(self, file_handle, parquet_schema):
+        parquet_writer = pq.ParquetWriter(file_handle.name, parquet_schema)
+        return parquet_writer
+
+    def _convert_parquet_schema(self, cursor):
+        type_map = {
+            'INTEGER': pa.int64(),
+            'FLOAT': pa.float64(),
+            'NUMERIC': pa.float64(),
+            'BIGNUMERIC': pa.float64(),
+            'BOOL': pa.bool_(),
+            'STRING': pa.string(),
+            'BYTES': pa.binary(),
+            'DATE': pa.date32(),
+            'DATETIME': pa.date64(),
+            'TIMESTAMP': pa.timestamp('s'),
+        }
+
+        columns = [field[0] for field in cursor.description]
+        bq_types = [self.field_to_bigquery(field) for field in cursor.description]
+        pq_types = [type_map.get(bq_type, pa.string()) for bq_type in bq_types]
+        parquet_schema = pa.schema(zip(columns, pq_types))
+        return parquet_schema
+
     @abc.abstractmethod
     def query(self):
         """Execute DBAPI query."""
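The hunks above cover the operator change; the remaining hunks below update the accompanying unit tests. The snippet that follows is not part of the commit: it is a minimal, standalone sketch of the row-to-Parquet path the new branch takes (build a one-row pyarrow Table from a column dict, then append it through a ParquetWriter). It assumes pyarrow is installed; the column names, types, and output path are illustrative only.

# Standalone sketch (not from the commit): how the new parquet branch writes rows.
# Assumes pyarrow is installed; column names, types, and the output path are made up.
import pyarrow as pa
import pyarrow.parquet as pq

# Schema analogous to what _convert_parquet_schema builds from cursor.description.
parquet_schema = pa.schema([("id", pa.int64()), ("name", pa.string())])

writer = pq.ParquetWriter("/tmp/example.parquet", parquet_schema)
for row in [(1, "alice"), (2, "bob")]:
    # One single-row table per cursor row, mirroring Table.from_pydict in the diff.
    row_pydict = {col: [value] for col, value in zip(parquet_schema.names, row)}
    writer.write_table(pa.Table.from_pydict(row_pydict, parquet_schema))
writer.close()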
@@ -18,8 +18,9 @@
 import json
 import unittest
 from unittest import mock
-from unittest.mock import Mock
+from unittest.mock import MagicMock, Mock
 
+import pandas as pd
 import unicodecsv as csv
 
 from airflow.providers.google.cloud.hooks.gcs import GCSHook
@@ -36,6 +37,11 @@ SCHEMA = [
 ]
 COLUMNS = ["column_a", "column_b", "column_c"]
 ROW = ["convert_type_return_value", "convert_type_return_value", "convert_type_return_value"]
+CURSOR_DESCRIPTION = [
+    ("column_a", "3", 0, 0, 0, 0, False),
+    ("column_b", "253", 0, 0, 0, 0, False),
+    ("column_c", "10", 0, 0, 0, 0, False),
+]
 TMP_FILE_NAME = "temp-file"
 INPUT_DATA = [
     ["101", "school", "2015-01-01"],
@@ -52,13 +58,15 @@ OUTPUT_DATA = json.dumps(
 SCHEMA_FILE = "schema_file.json"
 APP_JSON = "application/json"
 
+OUTPUT_DF = pd.DataFrame([['convert_type_return_value'] * 3] * 3, columns=COLUMNS)
+
 
 class DummySQLToGCSOperator(BaseSQLToGCSOperator):
     def field_to_bigquery(self, field):
         pass
 
     def convert_type(self, value, schema_type):
-        pass
+        return 'convert_type_return_value'
 
     def query(self):
         pass
@@ -69,13 +77,10 @@ class TestBaseSQLToGCSOperator(unittest.TestCase):
     @mock.patch.object(csv.writer, "writerow")
     @mock.patch.object(GCSHook, "upload")
     @mock.patch.object(DummySQLToGCSOperator, "query")
-    @mock.patch.object(DummySQLToGCSOperator, "field_to_bigquery")
     @mock.patch.object(DummySQLToGCSOperator, "convert_type")
-    def test_exec(
-        self, mock_convert_type, mock_field_to_bigquery, mock_query, mock_upload, mock_writerow, mock_tempfile
-    ):
+    def test_exec(self, mock_convert_type, mock_query, mock_upload, mock_writerow, mock_tempfile):
         cursor_mock = Mock()
-        cursor_mock.description = [("column_a", "3"), ("column_b", "253"), ("column_c", "10")]
+        cursor_mock.description = CURSOR_DESCRIPTION
         cursor_mock.__iter__ = Mock(return_value=iter(INPUT_DATA))
         mock_query.return_value = cursor_mock
         mock_convert_type.return_value = "convert_type_return_value"
@@ -99,6 +104,7 @@ class TestBaseSQLToGCSOperator(unittest.TestCase):
 
         mock_tempfile.return_value = mock_file
 
+        # Test CSV
         operator = DummySQLToGCSOperator(
             sql=SQL,
             bucket=BUCKET,
@@ -109,7 +115,7 @@ class TestBaseSQLToGCSOperator(unittest.TestCase):
             export_format="csv",
             gzip=True,
             schema=SCHEMA,
-            google_cloud_storage_conn_id='google_cloud_default',
+            gcp_conn_id='google_cloud_default',
         )
         operator.execute(context=dict())
 
@@ -140,6 +146,7 @@ class TestBaseSQLToGCSOperator(unittest.TestCase):
 
         cursor_mock.__iter__ = Mock(return_value=iter(INPUT_DATA))
 
+        # Test JSON
        operator = DummySQLToGCSOperator(
             sql=SQL, bucket=BUCKET, filename=FILENAME, task_id=TASK_ID, export_format="json", schema=SCHEMA
         )
@@ -160,6 +167,27 @@ class TestBaseSQLToGCSOperator(unittest.TestCase):
         mock_upload.assert_called_once_with(BUCKET, FILENAME, TMP_FILE_NAME, mime_type=APP_JSON, gzip=False)
         mock_close.assert_called_once()
 
+        mock_query.reset_mock()
+        mock_flush.reset_mock()
+        mock_upload.reset_mock()
+        mock_close.reset_mock()
+        cursor_mock.reset_mock()
+
+        cursor_mock.__iter__ = Mock(return_value=iter(INPUT_DATA))
+
+        # Test parquet
+        operator = DummySQLToGCSOperator(
+            sql=SQL, bucket=BUCKET, filename=FILENAME, task_id=TASK_ID, export_format="parquet", schema=SCHEMA
+        )
+        operator.execute(context=dict())
+
+        mock_query.assert_called_once()
+        mock_flush.assert_called_once()
+        mock_upload.assert_called_once_with(
+            BUCKET, FILENAME, TMP_FILE_NAME, mime_type='application/octet-stream', gzip=False
+        )
+        mock_close.assert_called_once()
+
         # Test null marker
         cursor_mock.__iter__ = Mock(return_value=iter(INPUT_DATA))
         mock_convert_type.return_value = None
@@ -182,3 +210,69 @@ class TestBaseSQLToGCSOperator(unittest.TestCase):
                 mock.call(["NULL", "NULL", "NULL"]),
             ]
         )
+
+    def test__write_local_data_files_csv(self):
+        op = DummySQLToGCSOperator(
+            sql=SQL,
+            bucket=BUCKET,
+            filename=FILENAME,
+            task_id=TASK_ID,
+            schema_filename=SCHEMA_FILE,
+            export_format="csv",
+            gzip=False,
+            schema=SCHEMA,
+            gcp_conn_id='google_cloud_default',
+        )
+        cursor = MagicMock()
+        cursor.__iter__.return_value = INPUT_DATA
+        cursor.description = CURSOR_DESCRIPTION
+
+        files = op._write_local_data_files(cursor)
+        file = files[0]['file_handle']
+        file.flush()
+        df = pd.read_csv(file.name)
+        assert df.equals(OUTPUT_DF)
+
+    def test__write_local_data_files_json(self):
+        op = DummySQLToGCSOperator(
+            sql=SQL,
+            bucket=BUCKET,
+            filename=FILENAME,
+            task_id=TASK_ID,
+            schema_filename=SCHEMA_FILE,
+            export_format="json",
+            gzip=False,
+            schema=SCHEMA,
+            gcp_conn_id='google_cloud_default',
+        )
+        cursor = MagicMock()
+        cursor.__iter__.return_value = INPUT_DATA
+        cursor.description = CURSOR_DESCRIPTION
+
+        files = op._write_local_data_files(cursor)
+        file = files[0]['file_handle']
+        file.flush()
+        df = pd.read_json(file.name, orient='records', lines=True)
+        assert df.equals(OUTPUT_DF)
+
+    def test__write_local_data_files_parquet(self):
+        op = DummySQLToGCSOperator(
+            sql=SQL,
+            bucket=BUCKET,
+            filename=FILENAME,
+            task_id=TASK_ID,
+            schema_filename=SCHEMA_FILE,
+            export_format="parquet",
+            gzip=False,
+            schema=SCHEMA,
+            gcp_conn_id='google_cloud_default',
+        )
+        cursor = MagicMock()
+        cursor.__iter__.return_value = INPUT_DATA
+        cursor.description = CURSOR_DESCRIPTION
+
+        files = op._write_local_data_files(cursor)
+        file = files[0]['file_handle']
+        file.flush()
+        df = pd.read_parquet(file.name)
+        assert df.equals(OUTPUT_DF)
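Not part of the commit: a hypothetical usage sketch of the new option. Any concrete subclass of BaseSQLToGCSOperator, for example MySQLToGCSOperator, inherits export_format="parquet"; the task id, SQL, bucket, filenames, and connection ids below are placeholders, not values from this change.

# Hypothetical usage sketch (not from the commit): a concrete subclass picks up
# the new export_format="parquet" option from BaseSQLToGCSOperator.
from airflow.providers.google.cloud.transfers.mysql_to_gcs import MySQLToGCSOperator

export_task = MySQLToGCSOperator(
    task_id="export_orders_parquet",          # placeholder task id
    sql="SELECT * FROM orders",                # placeholder query
    bucket="example-bucket",                   # placeholder GCS bucket
    filename="orders/export_{}.parquet",       # {} is filled with the file number
    export_format="parquet",                   # option added by this commit
    gcp_conn_id="google_cloud_default",
)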