Add Parquet data type to BaseSQLToGCSOperator (#13359)

Tuan Nguyen 2020-12-31 17:32:09 +07:00 committed by GitHub
Parent 10be37513c
Commit 406181d64a
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
2 changed files: 141 additions and 9 deletions
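
For context, a minimal usage sketch (not part of the diff below): any concrete subclass of BaseSQLToGCSOperator picks up the new Parquet path simply by passing export_format="parquet". The operator class MySQLToGCSOperator and the SQL/bucket values are assumptions for illustration; the parameter names mirror the tests in this commit.

# Illustrative sketch only; MySQLToGCSOperator is assumed to subclass
# BaseSQLToGCSOperator and therefore inherit the new export_format handling.
from airflow.providers.google.cloud.transfers.mysql_to_gcs import MySQLToGCSOperator

export_orders = MySQLToGCSOperator(
    task_id="export_orders",
    sql="SELECT * FROM orders",            # hypothetical query
    bucket="my-bucket",                    # hypothetical bucket
    filename="exports/orders_{}.parquet",  # {} is filled with a file counter when output is split
    export_format="parquet",               # new in this commit
    gcp_conn_id="google_cloud_default",
)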

View file

@@ -22,6 +22,8 @@ import warnings
from tempfile import NamedTemporaryFile
from typing import Optional, Sequence, Union
import pyarrow as pa
import pyarrow.parquet as pq
import unicodecsv as csv
from airflow.models import BaseOperator
@@ -185,6 +187,8 @@ class BaseSQLToGCSOperator(BaseOperator):
tmp_file_handle = NamedTemporaryFile(delete=True)
if self.export_format == 'csv':
file_mime_type = 'text/csv'
elif self.export_format == 'parquet':
file_mime_type = 'application/octet-stream'
else:
file_mime_type = 'application/json'
files_to_upload = [
@@ -198,6 +202,9 @@ class BaseSQLToGCSOperator(BaseOperator):
if self.export_format == 'csv':
csv_writer = self._configure_csv_file(tmp_file_handle, schema)
if self.export_format == 'parquet':
parquet_schema = self._convert_parquet_schema(cursor)
parquet_writer = self._configure_parquet_file(tmp_file_handle, parquet_schema)
for row in cursor:
# Convert datetime objects to utc seconds, and decimals to floats.
@@ -208,6 +215,12 @@ class BaseSQLToGCSOperator(BaseOperator):
if self.null_marker is not None:
row = [value if value is not None else self.null_marker for value in row]
csv_writer.writerow(row)
elif self.export_format == 'parquet':
if self.null_marker is not None:
row = [value if value is not None else self.null_marker for value in row]
row_pydic = {col: [value] for col, value in zip(schema, row)}
tbl = pa.Table.from_pydict(row_pydic, parquet_schema)
parquet_writer.write_table(tbl)
else:
row_dict = dict(zip(schema, row))
@@ -232,7 +245,8 @@ class BaseSQLToGCSOperator(BaseOperator):
self.log.info("Current file count: %d", len(files_to_upload))
if self.export_format == 'csv':
csv_writer = self._configure_csv_file(tmp_file_handle, schema)
if self.export_format == 'parquet':
parquet_writer = self._configure_parquet_file(tmp_file_handle, parquet_schema)
return files_to_upload
def _configure_csv_file(self, file_handle, schema):
@@ -243,6 +257,30 @@ class BaseSQLToGCSOperator(BaseOperator):
csv_writer.writerow(schema)
return csv_writer
def _configure_parquet_file(self, file_handle, parquet_schema):
parquet_writer = pq.ParquetWriter(file_handle.name, parquet_schema)
return parquet_writer
def _convert_parquet_schema(self, cursor):
type_map = {
'INTEGER': pa.int64(),
'FLOAT': pa.float64(),
'NUMERIC': pa.float64(),
'BIGNUMERIC': pa.float64(),
'BOOL': pa.bool_(),
'STRING': pa.string(),
'BYTES': pa.binary(),
'DATE': pa.date32(),
'DATETIME': pa.date64(),
'TIMESTAMP': pa.timestamp('s'),
}
columns = [field[0] for field in cursor.description]
bq_types = [self.field_to_bigquery(field) for field in cursor.description]
pq_types = [type_map.get(bq_type, pa.string()) for bq_type in bq_types]
parquet_schema = pa.schema(zip(columns, pq_types))
return parquet_schema
@abc.abstractmethod
def query(self):
"""Execute DBAPI query."""

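A standalone sketch of what the new _convert_parquet_schema above does, assuming a made-up DBAPI cursor description and a stand-in for the abstract field_to_bigquery hook (real subclasses map their own DB type codes):

import pyarrow as pa

description = [("id", 3), ("name", 253), ("ts", 7)]  # hypothetical cursor.description

def field_to_bigquery(field):
    # Stand-in for the abstract hook; assumed to return a BigQuery type name.
    return {3: "INTEGER", 253: "STRING", 7: "TIMESTAMP"}.get(field[1], "STRING")

type_map = {"INTEGER": pa.int64(), "STRING": pa.string(), "TIMESTAMP": pa.timestamp("s")}
columns = [field[0] for field in description]
pq_types = [type_map.get(field_to_bigquery(field), pa.string()) for field in description]
parquet_schema = pa.schema(zip(columns, pq_types))
# -> id: int64, name: string, ts: timestamp[s]
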
View file

@@ -18,8 +18,9 @@
import json
import unittest
from unittest import mock
from unittest.mock import Mock
from unittest.mock import MagicMock, Mock
import pandas as pd
import unicodecsv as csv
from airflow.providers.google.cloud.hooks.gcs import GCSHook
@@ -36,6 +37,11 @@ SCHEMA = [
]
COLUMNS = ["column_a", "column_b", "column_c"]
ROW = ["convert_type_return_value", "convert_type_return_value", "convert_type_return_value"]
CURSOR_DESCRIPTION = [
("column_a", "3", 0, 0, 0, 0, False),
("column_b", "253", 0, 0, 0, 0, False),
("column_c", "10", 0, 0, 0, 0, False),
]
TMP_FILE_NAME = "temp-file"
INPUT_DATA = [
["101", "school", "2015-01-01"],
@@ -52,13 +58,15 @@ OUTPUT_DATA = json.dumps(
SCHEMA_FILE = "schema_file.json"
APP_JSON = "application/json"
OUTPUT_DF = pd.DataFrame([['convert_type_return_value'] * 3] * 3, columns=COLUMNS)
class DummySQLToGCSOperator(BaseSQLToGCSOperator):
def field_to_bigquery(self, field):
pass
def convert_type(self, value, schema_type):
pass
return 'convert_type_return_value'
def query(self):
pass
@@ -69,13 +77,10 @@ class TestBaseSQLToGCSOperator(unittest.TestCase):
@mock.patch.object(csv.writer, "writerow")
@mock.patch.object(GCSHook, "upload")
@mock.patch.object(DummySQLToGCSOperator, "query")
@mock.patch.object(DummySQLToGCSOperator, "field_to_bigquery")
@mock.patch.object(DummySQLToGCSOperator, "convert_type")
def test_exec(
self, mock_convert_type, mock_field_to_bigquery, mock_query, mock_upload, mock_writerow, mock_tempfile
):
def test_exec(self, mock_convert_type, mock_query, mock_upload, mock_writerow, mock_tempfile):
cursor_mock = Mock()
cursor_mock.description = [("column_a", "3"), ("column_b", "253"), ("column_c", "10")]
cursor_mock.description = CURSOR_DESCRIPTION
cursor_mock.__iter__ = Mock(return_value=iter(INPUT_DATA))
mock_query.return_value = cursor_mock
mock_convert_type.return_value = "convert_type_return_value"
@@ -99,6 +104,7 @@ class TestBaseSQLToGCSOperator(unittest.TestCase):
mock_tempfile.return_value = mock_file
# Test CSV
operator = DummySQLToGCSOperator(
sql=SQL,
bucket=BUCKET,
@@ -109,7 +115,7 @@ class TestBaseSQLToGCSOperator(unittest.TestCase):
export_format="csv",
gzip=True,
schema=SCHEMA,
google_cloud_storage_conn_id='google_cloud_default',
gcp_conn_id='google_cloud_default',
)
operator.execute(context=dict())
@@ -140,6 +146,7 @@ class TestBaseSQLToGCSOperator(unittest.TestCase):
cursor_mock.__iter__ = Mock(return_value=iter(INPUT_DATA))
# Test JSON
operator = DummySQLToGCSOperator(
sql=SQL, bucket=BUCKET, filename=FILENAME, task_id=TASK_ID, export_format="json", schema=SCHEMA
)
@@ -160,6 +167,27 @@ class TestBaseSQLToGCSOperator(unittest.TestCase):
mock_upload.assert_called_once_with(BUCKET, FILENAME, TMP_FILE_NAME, mime_type=APP_JSON, gzip=False)
mock_close.assert_called_once()
mock_query.reset_mock()
mock_flush.reset_mock()
mock_upload.reset_mock()
mock_close.reset_mock()
cursor_mock.reset_mock()
cursor_mock.__iter__ = Mock(return_value=iter(INPUT_DATA))
# Test parquet
operator = DummySQLToGCSOperator(
sql=SQL, bucket=BUCKET, filename=FILENAME, task_id=TASK_ID, export_format="parquet", schema=SCHEMA
)
operator.execute(context=dict())
mock_query.assert_called_once()
mock_flush.assert_called_once()
mock_upload.assert_called_once_with(
BUCKET, FILENAME, TMP_FILE_NAME, mime_type='application/octet-stream', gzip=False
)
mock_close.assert_called_once()
# Test null marker
cursor_mock.__iter__ = Mock(return_value=iter(INPUT_DATA))
mock_convert_type.return_value = None
@@ -182,3 +210,69 @@ class TestBaseSQLToGCSOperator(unittest.TestCase):
mock.call(["NULL", "NULL", "NULL"]),
]
)
def test__write_local_data_files_csv(self):
op = DummySQLToGCSOperator(
sql=SQL,
bucket=BUCKET,
filename=FILENAME,
task_id=TASK_ID,
schema_filename=SCHEMA_FILE,
export_format="csv",
gzip=False,
schema=SCHEMA,
gcp_conn_id='google_cloud_default',
)
cursor = MagicMock()
cursor.__iter__.return_value = INPUT_DATA
cursor.description = CURSOR_DESCRIPTION
files = op._write_local_data_files(cursor)
file = files[0]['file_handle']
file.flush()
df = pd.read_csv(file.name)
assert df.equals(OUTPUT_DF)
def test__write_local_data_files_json(self):
op = DummySQLToGCSOperator(
sql=SQL,
bucket=BUCKET,
filename=FILENAME,
task_id=TASK_ID,
schema_filename=SCHEMA_FILE,
export_format="json",
gzip=False,
schema=SCHEMA,
gcp_conn_id='google_cloud_default',
)
cursor = MagicMock()
cursor.__iter__.return_value = INPUT_DATA
cursor.description = CURSOR_DESCRIPTION
files = op._write_local_data_files(cursor)
file = files[0]['file_handle']
file.flush()
df = pd.read_json(file.name, orient='records', lines=True)
assert df.equals(OUTPUT_DF)
def test__write_local_data_files_parquet(self):
op = DummySQLToGCSOperator(
sql=SQL,
bucket=BUCKET,
filename=FILENAME,
task_id=TASK_ID,
schema_filename=SCHEMA_FILE,
export_format="parquet",
gzip=False,
schema=SCHEMA,
gcp_conn_id='google_cloud_default',
)
cursor = MagicMock()
cursor.__iter__.return_value = INPUT_DATA
cursor.description = CURSOR_DESCRIPTION
files = op._write_local_data_files(cursor)
file = files[0]['file_handle']
file.flush()
df = pd.read_parquet(file.name)
assert df.equals(OUTPUT_DF)
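
The write path added in this commit builds one single-row pyarrow Table per cursor row and appends it through a ParquetWriter. A minimal sketch of that pattern outside the operator, with a pandas read-back similar to the new test__write_local_data_files_parquet (file name and rows are made up):

import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq

schema = pa.schema([("column_a", pa.string()), ("column_b", pa.string())])
writer = pq.ParquetWriter("example.parquet", schema)  # hypothetical local path

for row in [["101", "school"], ["102", "business"]]:
    # One single-row table per DBAPI row, mirroring _write_local_data_files.
    tbl = pa.Table.from_pydict({col: [val] for col, val in zip(schema.names, row)}, schema)
    writer.write_table(tbl)
writer.close()

print(pd.read_parquet("example.parquet"))  # two rows, string columns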