Add Parquet data type to BaseSQLToGCSOperator (#13359)

This commit is contained in:
Tuan Nguyen 2020-12-31 17:32:09 +07:00 committed by GitHub
Parent 10be37513c
Commit 406181d64a
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 141 additions and 9 deletions
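For orientation, here is a minimal usage sketch of the new export format from a concrete subclass of BaseSQLToGCSOperator. MySQLToGCSOperator is one such subclass; the DAG id, SQL, bucket, and object name below are illustrative placeholders, not values from this commit.

# Illustrative usage sketch (not part of the commit): MySQLToGCSOperator is one
# existing subclass of BaseSQLToGCSOperator; DAG id, SQL, bucket, and object name
# are placeholder assumptions.
from airflow import DAG
from airflow.providers.google.cloud.transfers.mysql_to_gcs import MySQLToGCSOperator
from airflow.utils.dates import days_ago

with DAG(dag_id="example_sql_to_gcs_parquet", start_date=days_ago(1), schedule_interval=None) as dag:
    export_table = MySQLToGCSOperator(
        task_id="export_table",
        sql="SELECT * FROM my_table",         # placeholder query
        bucket="my-bucket",                   # placeholder GCS bucket
        filename="exports/my_table.parquet",  # destination object name
        export_format="parquet",              # the format added by this commit
    )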

View file

@@ -22,6 +22,8 @@ import warnings
from tempfile import NamedTemporaryFile
from typing import Optional, Sequence, Union

import pyarrow as pa
import pyarrow.parquet as pq
import unicodecsv as csv

from airflow.models import BaseOperator
@@ -185,6 +187,8 @@ class BaseSQLToGCSOperator(BaseOperator):
        tmp_file_handle = NamedTemporaryFile(delete=True)
        if self.export_format == 'csv':
            file_mime_type = 'text/csv'
        elif self.export_format == 'parquet':
            file_mime_type = 'application/octet-stream'
        else:
            file_mime_type = 'application/json'
        files_to_upload = [
@@ -198,6 +202,9 @@ class BaseSQLToGCSOperator(BaseOperator):
        if self.export_format == 'csv':
            csv_writer = self._configure_csv_file(tmp_file_handle, schema)
        if self.export_format == 'parquet':
            parquet_schema = self._convert_parquet_schema(cursor)
            parquet_writer = self._configure_parquet_file(tmp_file_handle, parquet_schema)

        for row in cursor:
            # Convert datetime objects to utc seconds, and decimals to floats.
@@ -208,6 +215,12 @@ class BaseSQLToGCSOperator(BaseOperator):
                if self.null_marker is not None:
                    row = [value if value is not None else self.null_marker for value in row]
                csv_writer.writerow(row)
            elif self.export_format == 'parquet':
                if self.null_marker is not None:
                    row = [value if value is not None else self.null_marker for value in row]
                row_pydic = {col: [value] for col, value in zip(schema, row)}
                tbl = pa.Table.from_pydict(row_pydic, parquet_schema)
                parquet_writer.write_table(tbl)
            else:
                row_dict = dict(zip(schema, row))
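As a standalone illustration of the per-row write pattern added above: each database row becomes a one-row pyarrow Table matching the precomputed schema and is appended through a shared ParquetWriter. The schema, column names, rows, and output path below are assumptions for the example, not values from the commit.

# Minimal sketch of the row-at-a-time Parquet write pattern; all names and the
# output path are illustrative assumptions.
import pyarrow as pa
import pyarrow.parquet as pq

schema = ["id", "name"]
parquet_schema = pa.schema([("id", pa.int64()), ("name", pa.string())])

with pq.ParquetWriter("/tmp/example.parquet", parquet_schema) as parquet_writer:
    for row in [(1, "alpha"), (2, "beta")]:
        # Wrap each value in a single-element list so from_pydict builds a one-row table.
        row_pydict = {col: [value] for col, value in zip(schema, row)}
        parquet_writer.write_table(pa.Table.from_pydict(row_pydict, schema=parquet_schema))

Each write_table call emits its own (tiny) row group, so writing row by row keeps memory flat at the cost of many small row groups; it mirrors the existing CSV path, which also writes one row at a time.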
@@ -232,7 +245,8 @@ class BaseSQLToGCSOperator(BaseOperator):
                self.log.info("Current file count: %d", len(files_to_upload))
                if self.export_format == 'csv':
                    csv_writer = self._configure_csv_file(tmp_file_handle, schema)
                if self.export_format == 'parquet':
                    parquet_writer = self._configure_parquet_file(tmp_file_handle, parquet_schema)

        return files_to_upload

    def _configure_csv_file(self, file_handle, schema):
@@ -243,6 +257,30 @@ class BaseSQLToGCSOperator(BaseOperator):
        csv_writer.writerow(schema)
        return csv_writer

    def _configure_parquet_file(self, file_handle, parquet_schema):
        parquet_writer = pq.ParquetWriter(file_handle.name, parquet_schema)
        return parquet_writer

    def _convert_parquet_schema(self, cursor):
        type_map = {
            'INTEGER': pa.int64(),
            'FLOAT': pa.float64(),
            'NUMERIC': pa.float64(),
            'BIGNUMERIC': pa.float64(),
            'BOOL': pa.bool_(),
            'STRING': pa.string(),
            'BYTES': pa.binary(),
            'DATE': pa.date32(),
            'DATETIME': pa.date64(),
            'TIMESTAMP': pa.timestamp('s'),
        }

        columns = [field[0] for field in cursor.description]
        bq_types = [self.field_to_bigquery(field) for field in cursor.description]
        pq_types = [type_map.get(bq_type, pa.string()) for bq_type in bq_types]
        parquet_schema = pa.schema(zip(columns, pq_types))
        return parquet_schema
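To make the fallback concrete: any BigQuery type missing from type_map is stored as a string column. A small sketch, where the sample column names and types are assumptions for illustration only.

# Sketch of the BigQuery-to-Arrow mapping with its string fallback; the sample
# columns and types below are assumptions, not values from the commit.
import pyarrow as pa

type_map = {'INTEGER': pa.int64(), 'FLOAT': pa.float64(), 'TIMESTAMP': pa.timestamp('s')}

columns = ['id', 'score', 'created_at', 'region']
bq_types = ['INTEGER', 'FLOAT', 'TIMESTAMP', 'GEOGRAPHY']  # GEOGRAPHY is not in the map

pq_types = [type_map.get(bq_type, pa.string()) for bq_type in bq_types]
parquet_schema = pa.schema(zip(columns, pq_types))
print(parquet_schema)  # 'region' falls back to pa.string()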

    @abc.abstractmethod
    def query(self):
        """Execute DBAPI query."""

View file

@@ -18,8 +18,9 @@
import json
import unittest
from unittest import mock
from unittest.mock import Mock
from unittest.mock import MagicMock, Mock

import pandas as pd
import unicodecsv as csv

from airflow.providers.google.cloud.hooks.gcs import GCSHook
@@ -36,6 +37,11 @@ SCHEMA = [
]
COLUMNS = ["column_a", "column_b", "column_c"]
ROW = ["convert_type_return_value", "convert_type_return_value", "convert_type_return_value"]
CURSOR_DESCRIPTION = [
    ("column_a", "3", 0, 0, 0, 0, False),
    ("column_b", "253", 0, 0, 0, 0, False),
    ("column_c", "10", 0, 0, 0, 0, False),
]
TMP_FILE_NAME = "temp-file"
INPUT_DATA = [
    ["101", "school", "2015-01-01"],
@@ -52,13 +58,15 @@ OUTPUT_DATA = json.dumps(
SCHEMA_FILE = "schema_file.json"
APP_JSON = "application/json"

OUTPUT_DF = pd.DataFrame([['convert_type_return_value'] * 3] * 3, columns=COLUMNS)

class DummySQLToGCSOperator(BaseSQLToGCSOperator):
    def field_to_bigquery(self, field):
        pass

    def convert_type(self, value, schema_type):
        pass
        return 'convert_type_return_value'

    def query(self):
        pass
@@ -69,13 +77,10 @@ class TestBaseSQLToGCSOperator(unittest.TestCase):
    @mock.patch.object(csv.writer, "writerow")
    @mock.patch.object(GCSHook, "upload")
    @mock.patch.object(DummySQLToGCSOperator, "query")
    @mock.patch.object(DummySQLToGCSOperator, "field_to_bigquery")
    @mock.patch.object(DummySQLToGCSOperator, "convert_type")
    def test_exec(
        self, mock_convert_type, mock_field_to_bigquery, mock_query, mock_upload, mock_writerow, mock_tempfile
    ):
    def test_exec(self, mock_convert_type, mock_query, mock_upload, mock_writerow, mock_tempfile):
        cursor_mock = Mock()
        cursor_mock.description = [("column_a", "3"), ("column_b", "253"), ("column_c", "10")]
        cursor_mock.description = CURSOR_DESCRIPTION
        cursor_mock.__iter__ = Mock(return_value=iter(INPUT_DATA))
        mock_query.return_value = cursor_mock
        mock_convert_type.return_value = "convert_type_return_value"
@@ -99,6 +104,7 @@ class TestBaseSQLToGCSOperator(unittest.TestCase):
        mock_tempfile.return_value = mock_file

        # Test CSV
        operator = DummySQLToGCSOperator(
            sql=SQL,
            bucket=BUCKET,
@@ -109,7 +115,7 @@ class TestBaseSQLToGCSOperator(unittest.TestCase):
            export_format="csv",
            gzip=True,
            schema=SCHEMA,
            google_cloud_storage_conn_id='google_cloud_default',
            gcp_conn_id='google_cloud_default',
        )

        operator.execute(context=dict())
@@ -140,6 +146,7 @@ class TestBaseSQLToGCSOperator(unittest.TestCase):
        cursor_mock.__iter__ = Mock(return_value=iter(INPUT_DATA))

        # Test JSON
        operator = DummySQLToGCSOperator(
            sql=SQL, bucket=BUCKET, filename=FILENAME, task_id=TASK_ID, export_format="json", schema=SCHEMA
        )
@@ -160,6 +167,27 @@ class TestBaseSQLToGCSOperator(unittest.TestCase):
        mock_upload.assert_called_once_with(BUCKET, FILENAME, TMP_FILE_NAME, mime_type=APP_JSON, gzip=False)
        mock_close.assert_called_once()

        mock_query.reset_mock()
        mock_flush.reset_mock()
        mock_upload.reset_mock()
        mock_close.reset_mock()
        cursor_mock.reset_mock()
        cursor_mock.__iter__ = Mock(return_value=iter(INPUT_DATA))

        # Test parquet
        operator = DummySQLToGCSOperator(
            sql=SQL, bucket=BUCKET, filename=FILENAME, task_id=TASK_ID, export_format="parquet", schema=SCHEMA
        )

        operator.execute(context=dict())

        mock_query.assert_called_once()
        mock_flush.assert_called_once()
        mock_upload.assert_called_once_with(
            BUCKET, FILENAME, TMP_FILE_NAME, mime_type='application/octet-stream', gzip=False
        )
        mock_close.assert_called_once()
        # Test null marker
        cursor_mock.__iter__ = Mock(return_value=iter(INPUT_DATA))
        mock_convert_type.return_value = None
@@ -182,3 +210,69 @@ class TestBaseSQLToGCSOperator(unittest.TestCase):
                mock.call(["NULL", "NULL", "NULL"]),
            ]
        )

    def test__write_local_data_files_csv(self):
        op = DummySQLToGCSOperator(
            sql=SQL,
            bucket=BUCKET,
            filename=FILENAME,
            task_id=TASK_ID,
            schema_filename=SCHEMA_FILE,
            export_format="csv",
            gzip=False,
            schema=SCHEMA,
            gcp_conn_id='google_cloud_default',
        )
        cursor = MagicMock()
        cursor.__iter__.return_value = INPUT_DATA
        cursor.description = CURSOR_DESCRIPTION

        files = op._write_local_data_files(cursor)
        file = files[0]['file_handle']
        file.flush()

        df = pd.read_csv(file.name)
        assert df.equals(OUTPUT_DF)

    def test__write_local_data_files_json(self):
        op = DummySQLToGCSOperator(
            sql=SQL,
            bucket=BUCKET,
            filename=FILENAME,
            task_id=TASK_ID,
            schema_filename=SCHEMA_FILE,
            export_format="json",
            gzip=False,
            schema=SCHEMA,
            gcp_conn_id='google_cloud_default',
        )
        cursor = MagicMock()
        cursor.__iter__.return_value = INPUT_DATA
        cursor.description = CURSOR_DESCRIPTION

        files = op._write_local_data_files(cursor)
        file = files[0]['file_handle']
        file.flush()

        df = pd.read_json(file.name, orient='records', lines=True)
        assert df.equals(OUTPUT_DF)

    def test__write_local_data_files_parquet(self):
        op = DummySQLToGCSOperator(
            sql=SQL,
            bucket=BUCKET,
            filename=FILENAME,
            task_id=TASK_ID,
            schema_filename=SCHEMA_FILE,
            export_format="parquet",
            gzip=False,
            schema=SCHEMA,
            gcp_conn_id='google_cloud_default',
        )
        cursor = MagicMock()
        cursor.__iter__.return_value = INPUT_DATA
        cursor.description = CURSOR_DESCRIPTION

        files = op._write_local_data_files(cursor)
        file = files[0]['file_handle']
        file.flush()

        df = pd.read_parquet(file.name)
        assert df.equals(OUTPUT_DF)
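Beyond comparing values, one might also want to inspect the Arrow schema the operator actually wrote, which the tests above do not check. A small optional sketch; the file path is a placeholder assumption.

# Optional follow-on check (not part of the commit's tests): inspect the schema of
# a file produced by _write_local_data_files; the path below is a placeholder,
# e.g. files[0]['file_handle'].name from the test above.
import pyarrow.parquet as pq

parquet_path = "/tmp/example.parquet"
table = pq.read_table(parquet_path)
print(table.schema)    # column names and the Arrow types chosen by _convert_parquet_schema
print(table.num_rows)  # one row per database row written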