[AIRFLOW-2458] Add cassandra-to-gcs operator
Closes #3354 from jgao54/cassandra-to-gcs
Parent: 8873a8df8a
Commit: f5115b7e6a
airflow/contrib/hooks/cassandra_hook.py
@@ -0,0 +1,88 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from cassandra.cluster import Cluster
from cassandra.policies import (RoundRobinPolicy, DCAwareRoundRobinPolicy,
                                TokenAwarePolicy, HostFilterPolicy,
                                WhiteListRoundRobinPolicy)
from cassandra.auth import PlainTextAuthProvider

from airflow.hooks.base_hook import BaseHook
from airflow.utils.log.logging_mixin import LoggingMixin


class CassandraHook(BaseHook, LoggingMixin):
    """
    Hook used to interact with Cassandra.

    Contact points can be specified as a comma-separated string in the 'host'
    field of the connection. The port can be specified in the port field of the
    connection. load_balancing_policy, ssl_options and cql_version can be
    specified in the extra field of the connection.

    For details of the Cluster config, see cassandra.cluster.
    """
    def __init__(self, cassandra_conn_id='cassandra_default'):
        conn = self.get_connection(cassandra_conn_id)

        conn_config = {}
        if conn.host:
            conn_config['contact_points'] = conn.host.split(',')

        if conn.port:
            conn_config['port'] = int(conn.port)

        if conn.login:
            conn_config['auth_provider'] = PlainTextAuthProvider(
                username=conn.login, password=conn.password)

        lb_policy = self.get_policy(conn.extra_dejson.get('load_balancing_policy', None))
        if lb_policy:
            conn_config['load_balancing_policy'] = lb_policy

        cql_version = conn.extra_dejson.get('cql_version', None)
        if cql_version:
            conn_config['cql_version'] = cql_version

        ssl_options = conn.extra_dejson.get('ssl_options', None)
        if ssl_options:
            conn_config['ssl_options'] = ssl_options

        self.cluster = Cluster(**conn_config)
        self.keyspace = conn.schema

    def get_conn(self):
        """
        Returns a cassandra Session object
        """
        return self.cluster.connect(self.keyspace)

    def get_cluster(self):
        return self.cluster

    @classmethod
    def get_policy(cls, policy_name):
        policies = {
            'RoundRobinPolicy': RoundRobinPolicy,
            'DCAwareRoundRobinPolicy': DCAwareRoundRobinPolicy,
            'TokenAwarePolicy': TokenAwarePolicy,
            'HostFilterPolicy': HostFilterPolicy,
            'WhiteListRoundRobinPolicy': WhiteListRoundRobinPolicy,
        }
        return policies.get(policy_name)
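A minimal usage sketch of the new hook, not part of this change. It assumes a Cassandra connection named 'cassandra_default' is registered in the Airflow metadata database (created by initdb further down in this commit); the keyspace and table names are hypothetical.

    from airflow.contrib.hooks.cassandra_hook import CassandraHook

    hook = CassandraHook(cassandra_conn_id='cassandra_default')
    # cassandra.cluster.Session bound to the connection's schema (keyspace)
    session = hook.get_conn()
    for row in session.execute('SELECT * FROM my_keyspace.my_table'):  # hypothetical keyspace/table
        print(row)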
airflow/contrib/operators/cassandra_to_gcs.py
@@ -0,0 +1,351 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import unicode_literals

import json
from builtins import str
from base64 import b64encode
from cassandra.util import Date, Time, SortedSet, OrderedMapSerializedKey
from datetime import datetime
from decimal import Decimal
from six import text_type, binary_type, PY3
from tempfile import NamedTemporaryFile
from uuid import UUID

from airflow.contrib.hooks.gcs_hook import GoogleCloudStorageHook
from airflow.contrib.hooks.cassandra_hook import CassandraHook
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.utils.decorators import apply_defaults


class CassandraToGoogleCloudStorageOperator(BaseOperator):
    """
    Copy data from Cassandra to Google Cloud Storage in JSON format.

    Note: Arrays of arrays are not supported.
    """
    template_fields = ('cql', 'bucket', 'filename', 'schema_filename',)
    template_ext = ('.cql',)
    ui_color = '#a0e08c'

    @apply_defaults
    def __init__(self,
                 cql,
                 bucket,
                 filename,
                 schema_filename=None,
                 approx_max_file_size_bytes=1900000000,
                 cassandra_conn_id='cassandra_default',
                 google_cloud_storage_conn_id='google_cloud_default',
                 delegate_to=None,
                 *args,
                 **kwargs):
        """
        :param cql: The CQL to execute on the Cassandra table.
        :type cql: string
        :param bucket: The bucket to upload to.
        :type bucket: string
        :param filename: The filename to use as the object name when uploading
            to Google Cloud Storage. A {} should be specified in the filename
            to allow the operator to inject file numbers in cases where the
            file is split due to size.
        :type filename: string
        :param schema_filename: If set, the filename to use as the object name
            when uploading a .json file containing the BigQuery schema fields
            for the table that was dumped from Cassandra.
        :type schema_filename: string
        :param approx_max_file_size_bytes: This operator supports the ability
            to split large table dumps into multiple files (see notes in the
            filename param docs above). Google Cloud Storage allows for files
            to be a maximum of 4GB. This param allows developers to specify the
            file size of the splits.
        :type approx_max_file_size_bytes: long
        :param cassandra_conn_id: Reference to a specific Cassandra hook.
        :type cassandra_conn_id: string
        :param google_cloud_storage_conn_id: Reference to a specific Google
            Cloud Storage hook.
        :type google_cloud_storage_conn_id: string
        :param delegate_to: The account to impersonate, if any. For this to
            work, the service account making the request must have domain-wide
            delegation enabled.
        :type delegate_to: string
        """
        super(CassandraToGoogleCloudStorageOperator, self).__init__(*args, **kwargs)
        self.cql = cql
        self.bucket = bucket
        self.filename = filename
        self.schema_filename = schema_filename
        self.approx_max_file_size_bytes = approx_max_file_size_bytes
        self.cassandra_conn_id = cassandra_conn_id
        self.google_cloud_storage_conn_id = google_cloud_storage_conn_id
        self.delegate_to = delegate_to

    # Default Cassandra to BigQuery type mapping
    CQL_TYPE_MAP = {
        'BytesType': 'BYTES',
        'DecimalType': 'FLOAT',
        'UUIDType': 'STRING',
        'BooleanType': 'BOOL',
        'ByteType': 'INTEGER',
        'AsciiType': 'STRING',
        'FloatType': 'FLOAT',
        'DoubleType': 'FLOAT',
        'LongType': 'INTEGER',
        'Int32Type': 'INTEGER',
        'IntegerType': 'INTEGER',
        'InetAddressType': 'STRING',
        'CounterColumnType': 'INTEGER',
        'DateType': 'TIMESTAMP',
        'SimpleDateType': 'DATE',
        'TimestampType': 'TIMESTAMP',
        'TimeUUIDType': 'BYTES',
        'ShortType': 'INTEGER',
        'TimeType': 'TIME',
        'DurationType': 'INTEGER',
        'UTF8Type': 'STRING',
        'VarcharType': 'STRING',
    }

    def execute(self, context):
        cursor = self._query_cassandra()
        files_to_upload = self._write_local_data_files(cursor)

        # If a schema is set, create a BQ schema JSON file.
        if self.schema_filename:
            files_to_upload.update(self._write_local_schema_file(cursor))

        # Flush all files before uploading
        for file_handle in files_to_upload.values():
            file_handle.flush()

        self._upload_to_gcs(files_to_upload)

        # Close all temp file handles.
        for file_handle in files_to_upload.values():
            file_handle.close()

    def _query_cassandra(self):
        """
        Queries cassandra and returns a cursor to the results.
        """
        hook = CassandraHook(cassandra_conn_id=self.cassandra_conn_id)
        session = hook.get_conn()
        cursor = session.execute(self.cql)
        return cursor

    def _write_local_data_files(self, cursor):
        """
        Takes a cursor, and writes results to a local file.

        :return: A dictionary where keys are filenames to be used as object
            names in GCS, and values are file handles to local files that
            contain the data for the GCS objects.
        """
        file_no = 0
        tmp_file_handle = NamedTemporaryFile(delete=True)
        tmp_file_handles = {self.filename.format(file_no): tmp_file_handle}
        for row in cursor:
            row_dict = self.generate_data_dict(row._fields, row)
            s = json.dumps(row_dict)
            if PY3:
                s = s.encode('utf-8')
            tmp_file_handle.write(s)

            # Append newline to make dumps BigQuery compatible.
            tmp_file_handle.write(b'\n')

            if tmp_file_handle.tell() >= self.approx_max_file_size_bytes:
                file_no += 1
                tmp_file_handle = NamedTemporaryFile(delete=True)
                tmp_file_handles[self.filename.format(file_no)] = tmp_file_handle

        return tmp_file_handles

    def _write_local_schema_file(self, cursor):
        """
        Takes a cursor, and writes the BigQuery schema for the results to a
        local file system.

        :return: A dictionary where key is a filename to be used as an object
            name in GCS, and values are file handles to local files that
            contain the BigQuery schema fields in .json format.
        """
        schema = []
        tmp_schema_file_handle = NamedTemporaryFile(delete=True)

        for name, type in zip(cursor.column_names, cursor.column_types):
            schema.append(self.generate_schema_dict(name, type))
        json_serialized_schema = json.dumps(schema)
        if PY3:
            json_serialized_schema = json_serialized_schema.encode('utf-8')

        tmp_schema_file_handle.write(json_serialized_schema)
        return {self.schema_filename: tmp_schema_file_handle}

    def _upload_to_gcs(self, files_to_upload):
        hook = GoogleCloudStorageHook(
            google_cloud_storage_conn_id=self.google_cloud_storage_conn_id,
            delegate_to=self.delegate_to)
        for object, tmp_file_handle in files_to_upload.items():
            hook.upload(self.bucket, object, tmp_file_handle.name, 'application/json')

    @classmethod
    def generate_data_dict(cls, names, values):
        row_dict = {}
        for name, value in zip(names, values):
            row_dict.update({name: cls.convert_value(name, value)})
        return row_dict

    @classmethod
    def convert_value(cls, name, value):
        if not value:
            return value
        elif isinstance(value, (text_type, int, float, bool, dict)):
            return value
        elif isinstance(value, binary_type):
            encoded_value = b64encode(value)
            if PY3:
                encoded_value = encoded_value.decode('ascii')
            return encoded_value
        elif isinstance(value, (datetime, Date, UUID)):
            return str(value)
        elif isinstance(value, Decimal):
            return float(value)
        elif isinstance(value, Time):
            return str(value).split('.')[0]
        elif isinstance(value, (list, SortedSet)):
            return cls.convert_array_types(name, value)
        elif hasattr(value, '_fields'):
            return cls.convert_user_type(name, value)
        elif isinstance(value, tuple):
            return cls.convert_tuple_type(name, value)
        elif isinstance(value, OrderedMapSerializedKey):
            return cls.convert_map_type(name, value)
        else:
            raise AirflowException('unexpected value: ' + str(value))

    @classmethod
    def convert_array_types(cls, name, value):
        return [cls.convert_value(name, nested_value) for nested_value in value]

    @classmethod
    def convert_user_type(cls, name, value):
        """
        Converts a user type to RECORD that contains n fields, where n is the
        number of attributes. Each element in the user type class will be
        converted to its corresponding data type in BQ.
        """
        names = value._fields
        values = [cls.convert_value(name, getattr(value, name)) for name in names]
        return cls.generate_data_dict(names, values)

    @classmethod
    def convert_tuple_type(cls, name, value):
        """
        Converts a tuple to RECORD that contains n fields, each will be
        converted to its corresponding data type in BQ and will be named
        'field_<index>', where index is determined by the order of the tuple
        elements defined in Cassandra.
        """
        names = ['field_' + str(i) for i in range(len(value))]
        values = [cls.convert_value(name, value) for name, value in zip(names, value)]
        return cls.generate_data_dict(names, values)

    @classmethod
    def convert_map_type(cls, name, value):
        """
        Converts a map to a repeated RECORD that contains two fields: 'key' and
        'value', each will be converted to its corresponding data type in BQ.
        """
        converted_map = []
        for k, v in zip(value.keys(), value.values()):
            converted_map.append({
                'key': cls.convert_value('key', k),
                'value': cls.convert_value('value', v)
            })
        return converted_map

    @classmethod
    def generate_schema_dict(cls, name, type):
        field_schema = dict()
        field_schema.update({'name': name})
        field_schema.update({'type': cls.get_bq_type(type)})
        field_schema.update({'mode': cls.get_bq_mode(type)})
        fields = cls.get_bq_fields(name, type)
        if fields:
            field_schema.update({'fields': fields})
        return field_schema

    @classmethod
    def get_bq_fields(cls, name, type):
        fields = []

        if not cls.is_simple_type(type):
            names, types = [], []

            if cls.is_array_type(type) and cls.is_record_type(type.subtypes[0]):
                names = type.subtypes[0].fieldnames
                types = type.subtypes[0].subtypes
            elif cls.is_record_type(type):
                names = type.fieldnames
                types = type.subtypes

            if types and not names and type.cassname == 'TupleType':
                names = ['field_' + str(i) for i in range(len(types))]
            elif types and not names and type.cassname == 'MapType':
                names = ['key', 'value']

            for name, type in zip(names, types):
                field = cls.generate_schema_dict(name, type)
                fields.append(field)

        return fields

    @classmethod
    def is_simple_type(cls, type):
        return type.cassname in CassandraToGoogleCloudStorageOperator.CQL_TYPE_MAP

    @classmethod
    def is_array_type(cls, type):
        return type.cassname in ['ListType', 'SetType']

    @classmethod
    def is_record_type(cls, type):
        return type.cassname in ['UserType', 'TupleType', 'MapType']

    @classmethod
    def get_bq_type(cls, type):
        if cls.is_simple_type(type):
            return CassandraToGoogleCloudStorageOperator.CQL_TYPE_MAP[type.cassname]
        elif cls.is_record_type(type):
            return 'RECORD'
        elif cls.is_array_type(type):
            return cls.get_bq_type(type.subtypes[0])
        else:
            raise AirflowException('Not a supported type: ' + type.cassname)

    @classmethod
    def get_bq_mode(cls, type):
        if cls.is_array_type(type) or type.cassname == 'MapType':
            return 'REPEATED'
        elif cls.is_record_type(type) or cls.is_simple_type(type):
            return 'NULLABLE'
        else:
            raise AirflowException('Not a supported type: ' + type.cassname)
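For orientation, a sketch of how the operator could be wired into a DAG; it is not taken from this change, and the DAG id, CQL, bucket and object names are made up. The {} in filename is where the operator injects a file number when the dump is split across multiple objects.

    from datetime import datetime

    from airflow import DAG
    from airflow.contrib.operators.cassandra_to_gcs import CassandraToGoogleCloudStorageOperator

    dag = DAG('cassandra_to_gcs_example', start_date=datetime(2018, 1, 1), schedule_interval=None)

    export_events = CassandraToGoogleCloudStorageOperator(
        task_id='export_events',
        cql='SELECT * FROM my_keyspace.events',
        bucket='my-gcs-bucket',
        filename='exports/events/part-{}.json',       # {} becomes 0, 1, 2, ... as files are split
        schema_filename='exports/events/schema.json',  # optional BigQuery schema dump
        cassandra_conn_id='cassandra_default',
        google_cloud_storage_conn_id='google_cloud_default',
        dag=dag)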
airflow/models.py
@@ -603,6 +603,7 @@ class Connection(Base, LoggingMixin):
        ('snowflake', 'Snowflake',),
        ('segment', 'Segment',),
        ('azure_data_lake', 'Azure Data Lake'),
        ('cassandra', 'Cassandra',),
    ]

    def __init__(

@@ -753,6 +754,9 @@ class Connection(Base, LoggingMixin):
            elif self.conn_type == 'azure_data_lake':
                from airflow.contrib.hooks.azure_data_lake_hook import AzureDataLakeHook
                return AzureDataLakeHook(azure_data_lake_conn_id=self.conn_id)
            elif self.conn_type == 'cassandra':
                from airflow.contrib.hooks.cassandra_hook import CassandraHook
                return CassandraHook(cassandra_conn_id=self.conn_id)
        except:
            pass
airflow/utils/db.py
@@ -276,6 +276,10 @@ def initdb(rbac=False):
        models.Connection(
            conn_id='azure_data_lake_default', conn_type='azure_data_lake',
            extra='{"tenant": "<TENANT>", "account_name": "<ACCOUNTNAME>" }'))
    merge_conn(
        models.Connection(
            conn_id='cassandra_default', conn_type='cassandra',
            host='localhost', port=9042))

    # Known event types
    KET = models.KnownEventType
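The default connection above only sets host and port. The sketch below shows how a Cassandra connection carrying the hook's optional extras could be registered the same way; the conn_id, hosts and option values are hypothetical.

    from airflow import models
    from airflow.utils import db

    db.merge_conn(
        models.Connection(
            conn_id='cassandra_prod', conn_type='cassandra',
            host='cas-1.example.com,cas-2.example.com', port=9042,
            schema='events',  # used as the keyspace passed to Cluster.connect()
            extra='{"load_balancing_policy": "DCAwareRoundRobinPolicy", "cql_version": "3.4.4"}'))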
@@ -121,6 +121,7 @@ Operators
.. autoclass:: airflow.contrib.operators.bigquery_table_delete_operator.BigQueryTableDeleteOperator
.. autoclass:: airflow.contrib.operators.bigquery_to_bigquery.BigQueryToBigQueryOperator
.. autoclass:: airflow.contrib.operators.bigquery_to_gcs.BigQueryToCloudStorageOperator
.. autoclass:: airflow.contrib.operators.cassandra_to_gcs.CassandraToGoogleCloudStorageOperator
.. autoclass:: airflow.contrib.operators.databricks_operator.DatabricksSubmitRunOperator
.. autoclass:: airflow.contrib.operators.dataflow_operator.DataFlowJavaOperator
.. autoclass:: airflow.contrib.operators.dataflow_operator.DataflowTemplateOperator

@@ -354,6 +355,7 @@ Community contributed hooks
.. autoclass:: airflow.contrib.hooks.aws_hook.AwsHook
.. autoclass:: airflow.contrib.hooks.aws_lambda_hook.AwsLambdaHook
.. autoclass:: airflow.contrib.hooks.bigquery_hook.BigQueryHook
.. autoclass:: airflow.contrib.hooks.cassandra_hook.CassandraHook
.. autoclass:: airflow.contrib.hooks.cloudant_hook.CloudantHook
.. autoclass:: airflow.contrib.hooks.databricks_hook.DatabricksHook
.. autoclass:: airflow.contrib.hooks.datadog_hook.DatadogHook
setup.py
@@ -114,7 +114,7 @@ azure_data_lake = [
    'azure-mgmt-datalake-store==0.4.0',
    'azure-datalake-store==0.0.19'
]
sendgrid = ['sendgrid>=5.2.0']
cassandra = ['cassandra-driver>=3.13.0']
celery = [
    'celery>=4.0.2',
    'flower>=0.7.3'

@@ -184,6 +184,7 @@ s3 = ['boto3>=1.7.0']
salesforce = ['simple-salesforce>=0.72']
samba = ['pysmbclient>=0.1.3']
segment = ['analytics-python>=1.2.9']
sendgrid = ['sendgrid>=5.2.0']
slack = ['slackclient>=1.0.0']
snowflake = ['snowflake-connector-python>=1.5.2',
             'snowflake-sqlalchemy>=1.1.0']

@@ -194,7 +195,8 @@ webhdfs = ['hdfs[dataframe,avro,kerberos]>=2.0.4']
winrm = ['pywinrm==0.2.2']
zendesk = ['zdesk']

all_dbs = postgres + mysql + hive + mssql + hdfs + vertica + cloudant + druid + pinot
all_dbs = postgres + mysql + hive + mssql + hdfs + vertica + cloudant + druid + pinot \
    + cassandra
devel = [
    'click',
    'freezegun',

@@ -290,6 +292,7 @@ def do_setup():
            'async': async,
            'azure_blob_storage': azure_blob_storage,
            'azure_data_lake': azure_data_lake,
            'cassandra': cassandra,
            'celery': celery,
            'cgroups': cgroups,
            'cloudant': cloudant,
@@ -0,0 +1,56 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import unittest
import mock

from airflow import configuration
from airflow.contrib.hooks.cassandra_hook import CassandraHook
from cassandra.cluster import Cluster
from cassandra.policies import TokenAwarePolicy
from airflow import models
from airflow.utils import db


class CassandraHookTest(unittest.TestCase):
    def setUp(self):
        configuration.load_test_config()
        db.merge_conn(
            models.Connection(
                conn_id='cassandra_test', conn_type='cassandra',
                host='host-1,host-2', port='9042', schema='test_keyspace',
                extra='{"load_balancing_policy":"TokenAwarePolicy"}'))

    def test_get_conn(self):
        with mock.patch.object(Cluster, "connect") as mock_connect, \
                mock.patch("socket.getaddrinfo", return_value=[]) as mock_getaddrinfo:
            mock_connect.return_value = 'session'
            hook = CassandraHook(cassandra_conn_id='cassandra_test')
            hook.get_conn()
            mock_getaddrinfo.assert_called()
            mock_connect.assert_called_once_with('test_keyspace')

            cluster = hook.get_cluster()
            self.assertEqual(cluster.contact_points, ['host-1', 'host-2'])
            self.assertEqual(cluster.port, 9042)
            self.assertTrue(isinstance(cluster.load_balancing_policy, TokenAwarePolicy))


if __name__ == '__main__':
    unittest.main()
@@ -0,0 +1,92 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import unicode_literals

import unittest
import mock
from builtins import str
from airflow.contrib.operators.cassandra_to_gcs import \
    CassandraToGoogleCloudStorageOperator


class CassandraToGCSTest(unittest.TestCase):

    @mock.patch('airflow.contrib.hooks.gcs_hook.GoogleCloudStorageHook.upload')
    @mock.patch('airflow.contrib.hooks.cassandra_hook.CassandraHook.get_conn')
    def test_execute(self, get_conn, upload):
        # Note: the innermost patch (get_conn) is passed first.
        operator = CassandraToGoogleCloudStorageOperator(
            task_id='test-cas-to-gcs',
            cql='select * from keyspace1.table1',
            bucket='test-bucket',
            filename='data.json',
            schema_filename='schema.json')

        operator.execute(None)

        self.assertTrue(get_conn.called)
        self.assertTrue(upload.called)

    def test_convert_value(self):
        op = CassandraToGoogleCloudStorageOperator
        self.assertEquals(op.convert_value('None', None), None)
        self.assertEquals(op.convert_value('int', 1), 1)
        self.assertEquals(op.convert_value('float', 1.0), 1.0)
        self.assertEquals(op.convert_value('str', "text"), "text")
        self.assertEquals(op.convert_value('bool', True), True)
        self.assertEquals(op.convert_value('dict', {"a": "b"}), {"a": "b"})

        from datetime import datetime
        now = datetime.now()
        self.assertEquals(op.convert_value('datetime', now), str(now))

        from cassandra.util import Date
        date_str = '2018-01-01'
        date = Date(date_str)
        self.assertEquals(op.convert_value('date', date), str(date_str))

        import uuid
        test_uuid = uuid.uuid4()
        self.assertEquals(op.convert_value('uuid', test_uuid), str(test_uuid))

        from decimal import Decimal
        d = Decimal(1.0)
        self.assertEquals(op.convert_value('decimal', d), float(d))

        from base64 import b64encode
        b = b'abc'
        encoded_b = b64encode(b).decode('ascii')
        self.assertEquals(op.convert_value('binary', b), encoded_b)

        from cassandra.util import Time
        time = Time(0)
        self.assertEquals(op.convert_value('time', time), '00:00:00')

        date_str_lst = ['2018-01-01', '2018-01-02', '2018-01-03']
        date_lst = [Date(d) for d in date_str_lst]
        self.assertEquals(op.convert_value('list', date_lst), date_str_lst)

        date_tpl = tuple(date_lst)
        self.assertEquals(op.convert_value('tuple', date_tpl),
                          {'field_0': '2018-01-01',
                           'field_1': '2018-01-02',
                           'field_2': '2018-01-03', })


if __name__ == '__main__':
    unittest.main()
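The tests above exercise scalars, lists and tuples but not maps. As a reference point, a small sketch of what convert_map_type produces; the values are hypothetical, and a plain dict stands in for cassandra.util.OrderedMapSerializedKey since only keys() and values() are used.

    from airflow.contrib.operators.cassandra_to_gcs import \
        CassandraToGoogleCloudStorageOperator as Op

    # Each map entry becomes a {'key': ..., 'value': ...} record.
    converted = Op.convert_map_type('my_map', {'a': 1, 'b': 2})
    # converted == [{'key': 'a', 'value': 1}, {'key': 'b', 'value': 2}]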