Add Parquet data type to BaseSQLToGCSOperator (#13359)
This commit is contained in:
Parent: 10be37513c
Commit: 406181d64a
airflow/providers/google/cloud/transfers/sql_to_gcs.py

@@ -22,6 +22,8 @@ import warnings
 from tempfile import NamedTemporaryFile
 from typing import Optional, Sequence, Union
 
+import pyarrow as pa
+import pyarrow.parquet as pq
 import unicodecsv as csv
 
 from airflow.models import BaseOperator
@@ -185,6 +187,8 @@ class BaseSQLToGCSOperator(BaseOperator):
         tmp_file_handle = NamedTemporaryFile(delete=True)
         if self.export_format == 'csv':
             file_mime_type = 'text/csv'
+        elif self.export_format == 'parquet':
+            file_mime_type = 'application/octet-stream'
         else:
             file_mime_type = 'application/json'
         files_to_upload = [
@@ -198,6 +202,9 @@ class BaseSQLToGCSOperator(BaseOperator):
 
         if self.export_format == 'csv':
             csv_writer = self._configure_csv_file(tmp_file_handle, schema)
+        if self.export_format == 'parquet':
+            parquet_schema = self._convert_parquet_schema(cursor)
+            parquet_writer = self._configure_parquet_file(tmp_file_handle, parquet_schema)
 
         for row in cursor:
             # Convert datetime objects to utc seconds, and decimals to floats.
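Note: the hunk above configures the parquet writer once, before the row loop. Unlike the csv branch, which writes through the open temp-file handle, pq.ParquetWriter is opened against file_handle.name and must be given the Arrow schema up front. A minimal standalone sketch of that setup (the schema and path here are toy values, not from the commit):

import pyarrow as pa
import pyarrow.parquet as pq

# Stand-in for the schema _convert_parquet_schema derives from the cursor.
parquet_schema = pa.schema([('column_a', pa.int64()), ('column_b', pa.string())])

# Created once per output file; every later write_table() call must pass
# a table with exactly this schema.
parquet_writer = pq.ParquetWriter('/tmp/example.parquet', parquet_schema)
parquet_writer.close()  # flushes the parquet footer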
@@ -208,6 +215,12 @@ class BaseSQLToGCSOperator(BaseOperator):
                 if self.null_marker is not None:
                     row = [value if value is not None else self.null_marker for value in row]
                 csv_writer.writerow(row)
+            elif self.export_format == 'parquet':
+                if self.null_marker is not None:
+                    row = [value if value is not None else self.null_marker for value in row]
+                row_pydic = {col: [value] for col, value in zip(schema, row)}
+                tbl = pa.Table.from_pydict(row_pydic, parquet_schema)
+                parquet_writer.write_table(tbl)
             else:
                 row_dict = dict(zip(schema, row))
 
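For parquet, each database row becomes a one-row pa.Table appended with write_table. A sketch of that conversion in isolation (column names and values invented for illustration); note that each write_table call emits a separate row group, so this favors simplicity over file compactness:

import pyarrow as pa

schema = ['column_a', 'column_b']
row = [1, 'school']
parquet_schema = pa.schema([('column_a', pa.int64()), ('column_b', pa.string())])

# Wrap every value in a single-element list to get a one-row table.
row_pydic = {col: [value] for col, value in zip(schema, row)}
tbl = pa.Table.from_pydict(row_pydic, schema=parquet_schema)
assert tbl.num_rows == 1
# parquet_writer.write_table(tbl) would append it as its own row group.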
@@ -232,7 +245,8 @@ class BaseSQLToGCSOperator(BaseOperator):
                 self.log.info("Current file count: %d", len(files_to_upload))
                 if self.export_format == 'csv':
                     csv_writer = self._configure_csv_file(tmp_file_handle, schema)
-
+                if self.export_format == 'parquet':
+                    parquet_writer = self._configure_parquet_file(tmp_file_handle, parquet_schema)
         return files_to_upload
 
     def _configure_csv_file(self, file_handle, schema):
@@ -243,6 +257,30 @@ class BaseSQLToGCSOperator(BaseOperator):
         csv_writer.writerow(schema)
         return csv_writer
 
+    def _configure_parquet_file(self, file_handle, parquet_schema):
+        parquet_writer = pq.ParquetWriter(file_handle.name, parquet_schema)
+        return parquet_writer
+
+    def _convert_parquet_schema(self, cursor):
+        type_map = {
+            'INTEGER': pa.int64(),
+            'FLOAT': pa.float64(),
+            'NUMERIC': pa.float64(),
+            'BIGNUMERIC': pa.float64(),
+            'BOOL': pa.bool_(),
+            'STRING': pa.string(),
+            'BYTES': pa.binary(),
+            'DATE': pa.date32(),
+            'DATETIME': pa.date64(),
+            'TIMESTAMP': pa.timestamp('s'),
+        }
+
+        columns = [field[0] for field in cursor.description]
+        bq_types = [self.field_to_bigquery(field) for field in cursor.description]
+        pq_types = [type_map.get(bq_type, pa.string()) for bq_type in bq_types]
+        parquet_schema = pa.schema(zip(columns, pq_types))
+        return parquet_schema
+
     @abc.abstractmethod
     def query(self):
         """Execute DBAPI query."""
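_convert_parquet_schema keys its mapping off the BigQuery type names produced by field_to_bigquery, and anything unrecognized falls back to pa.string(). A quick sketch of that fallback behavior with made-up cursor metadata:

import pyarrow as pa

type_map = {'INTEGER': pa.int64(), 'FLOAT': pa.float64()}

columns = ['id', 'price', 'note']
bq_types = ['INTEGER', 'FLOAT', 'SOME_UNKNOWN_TYPE']

# Unknown types degrade to string columns instead of raising.
pq_types = [type_map.get(t, pa.string()) for t in bq_types]
parquet_schema = pa.schema(zip(columns, pq_types))
print(parquet_schema)  # id: int64, price: double, note: string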
tests/providers/google/cloud/transfers/test_sql_to_gcs.py

@@ -18,8 +18,9 @@
 import json
 import unittest
 from unittest import mock
-from unittest.mock import Mock
+from unittest.mock import MagicMock, Mock
 
+import pandas as pd
 import unicodecsv as csv
 
 from airflow.providers.google.cloud.hooks.gcs import GCSHook
@@ -36,6 +37,11 @@ SCHEMA = [
 ]
 COLUMNS = ["column_a", "column_b", "column_c"]
 ROW = ["convert_type_return_value", "convert_type_return_value", "convert_type_return_value"]
+CURSOR_DESCRIPTION = [
+    ("column_a", "3", 0, 0, 0, 0, False),
+    ("column_b", "253", 0, 0, 0, 0, False),
+    ("column_c", "10", 0, 0, 0, 0, False),
+]
 TMP_FILE_NAME = "temp-file"
 INPUT_DATA = [
     ["101", "school", "2015-01-01"],
@@ -52,13 +58,15 @@ OUTPUT_DATA = json.dumps(
 SCHEMA_FILE = "schema_file.json"
 APP_JSON = "application/json"
 
+OUTPUT_DF = pd.DataFrame([['convert_type_return_value'] * 3] * 3, columns=COLUMNS)
+
 
 class DummySQLToGCSOperator(BaseSQLToGCSOperator):
     def field_to_bigquery(self, field):
         pass
 
     def convert_type(self, value, schema_type):
-        pass
+        return 'convert_type_return_value'
 
     def query(self):
         pass
@@ -69,13 +77,10 @@ class TestBaseSQLToGCSOperator(unittest.TestCase):
     @mock.patch.object(csv.writer, "writerow")
     @mock.patch.object(GCSHook, "upload")
     @mock.patch.object(DummySQLToGCSOperator, "query")
-    @mock.patch.object(DummySQLToGCSOperator, "field_to_bigquery")
     @mock.patch.object(DummySQLToGCSOperator, "convert_type")
-    def test_exec(
-        self, mock_convert_type, mock_field_to_bigquery, mock_query, mock_upload, mock_writerow, mock_tempfile
-    ):
+    def test_exec(self, mock_convert_type, mock_query, mock_upload, mock_writerow, mock_tempfile):
         cursor_mock = Mock()
-        cursor_mock.description = [("column_a", "3"), ("column_b", "253"), ("column_c", "10")]
+        cursor_mock.description = CURSOR_DESCRIPTION
         cursor_mock.__iter__ = Mock(return_value=iter(INPUT_DATA))
         mock_query.return_value = cursor_mock
         mock_convert_type.return_value = "convert_type_return_value"
@@ -99,6 +104,7 @@ class TestBaseSQLToGCSOperator(unittest.TestCase):
 
         mock_tempfile.return_value = mock_file
 
+        # Test CSV
         operator = DummySQLToGCSOperator(
             sql=SQL,
             bucket=BUCKET,
@@ -109,7 +115,7 @@ class TestBaseSQLToGCSOperator(unittest.TestCase):
             export_format="csv",
             gzip=True,
             schema=SCHEMA,
-            google_cloud_storage_conn_id='google_cloud_default',
+            gcp_conn_id='google_cloud_default',
         )
         operator.execute(context=dict())
 
@@ -140,6 +146,7 @@ class TestBaseSQLToGCSOperator(unittest.TestCase):
 
         cursor_mock.__iter__ = Mock(return_value=iter(INPUT_DATA))
 
+        # Test JSON
         operator = DummySQLToGCSOperator(
             sql=SQL, bucket=BUCKET, filename=FILENAME, task_id=TASK_ID, export_format="json", schema=SCHEMA
         )
@@ -160,6 +167,27 @@ class TestBaseSQLToGCSOperator(unittest.TestCase):
         mock_upload.assert_called_once_with(BUCKET, FILENAME, TMP_FILE_NAME, mime_type=APP_JSON, gzip=False)
         mock_close.assert_called_once()
 
+        mock_query.reset_mock()
+        mock_flush.reset_mock()
+        mock_upload.reset_mock()
+        mock_close.reset_mock()
+        cursor_mock.reset_mock()
+
+        cursor_mock.__iter__ = Mock(return_value=iter(INPUT_DATA))
+
+        # Test parquet
+        operator = DummySQLToGCSOperator(
+            sql=SQL, bucket=BUCKET, filename=FILENAME, task_id=TASK_ID, export_format="parquet", schema=SCHEMA
+        )
+        operator.execute(context=dict())
+
+        mock_query.assert_called_once()
+        mock_flush.assert_called_once()
+        mock_upload.assert_called_once_with(
+            BUCKET, FILENAME, TMP_FILE_NAME, mime_type='application/octet-stream', gzip=False
+        )
+        mock_close.assert_called_once()
+
         # Test null marker
         cursor_mock.__iter__ = Mock(return_value=iter(INPUT_DATA))
         mock_convert_type.return_value = None
@@ -182,3 +210,69 @@ class TestBaseSQLToGCSOperator(unittest.TestCase):
                 mock.call(["NULL", "NULL", "NULL"]),
             ]
         )
+
+    def test__write_local_data_files_csv(self):
+        op = DummySQLToGCSOperator(
+            sql=SQL,
+            bucket=BUCKET,
+            filename=FILENAME,
+            task_id=TASK_ID,
+            schema_filename=SCHEMA_FILE,
+            export_format="csv",
+            gzip=False,
+            schema=SCHEMA,
+            gcp_conn_id='google_cloud_default',
+        )
+        cursor = MagicMock()
+        cursor.__iter__.return_value = INPUT_DATA
+        cursor.description = CURSOR_DESCRIPTION
+
+        files = op._write_local_data_files(cursor)
+        file = files[0]['file_handle']
+        file.flush()
+        df = pd.read_csv(file.name)
+        assert df.equals(OUTPUT_DF)
+
+    def test__write_local_data_files_json(self):
+        op = DummySQLToGCSOperator(
+            sql=SQL,
+            bucket=BUCKET,
+            filename=FILENAME,
+            task_id=TASK_ID,
+            schema_filename=SCHEMA_FILE,
+            export_format="json",
+            gzip=False,
+            schema=SCHEMA,
+            gcp_conn_id='google_cloud_default',
+        )
+        cursor = MagicMock()
+        cursor.__iter__.return_value = INPUT_DATA
+        cursor.description = CURSOR_DESCRIPTION
+
+        files = op._write_local_data_files(cursor)
+        file = files[0]['file_handle']
+        file.flush()
+        df = pd.read_json(file.name, orient='records', lines=True)
+        assert df.equals(OUTPUT_DF)
+
+    def test__write_local_data_files_parquet(self):
+        op = DummySQLToGCSOperator(
+            sql=SQL,
+            bucket=BUCKET,
+            filename=FILENAME,
+            task_id=TASK_ID,
+            schema_filename=SCHEMA_FILE,
+            export_format="parquet",
+            gzip=False,
+            schema=SCHEMA,
+            gcp_conn_id='google_cloud_default',
+        )
+        cursor = MagicMock()
+        cursor.__iter__.return_value = INPUT_DATA
+        cursor.description = CURSOR_DESCRIPTION
+
+        files = op._write_local_data_files(cursor)
+        file = files[0]['file_handle']
+        file.flush()
+        df = pd.read_parquet(file.name)
+        assert df.equals(OUTPUT_DF)
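For context, a hedged usage sketch of the new export format on a concrete subclass (MySQLToGCSOperator is one existing implementation of BaseSQLToGCSOperator; the query, bucket, filename, and connection ID below are placeholders, not values from this commit):

from airflow.providers.google.cloud.transfers.mysql_to_gcs import MySQLToGCSOperator

upload_products = MySQLToGCSOperator(
    task_id='products_to_gcs',
    sql='SELECT * FROM products',            # placeholder query
    bucket='my-bucket',                      # placeholder GCS bucket
    filename='exports/products_{}.parquet',  # {} is filled with the file counter on splits
    export_format='parquet',                 # the format added by this commit
    gcp_conn_id='google_cloud_default',
)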