[AIRFLOW-2333] Add Segment Hook and TrackEventOperator
Add support for Segment with an accompanying hook and an operator for sending track events. Closes #3335 from jzucker2/add-segment-support
Parent: 250faad0f5
Commit: 4d43b78f11
@@ -0,0 +1,92 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
"""
This module contains a Segment Hook
which allows you to connect to your Segment account,
retrieve data from it, or write to it.

NOTE: this hook also relies on the Segment analytics package:
https://github.com/segmentio/analytics-python
"""
import analytics
from airflow.hooks.base_hook import BaseHook
from airflow.exceptions import AirflowException

from airflow.utils.log.logging_mixin import LoggingMixin


class SegmentHook(BaseHook, LoggingMixin):
    def __init__(
            self,
            segment_conn_id='segment_default',
            segment_debug_mode=False,
            *args,
            **kwargs
    ):
        """
        Creates a new connection to Segment
        and allows you to pull data out of Segment or write to it.

        You can then use that connection with other
        Airflow operators to move the data around or interact with Segment.

        :param segment_conn_id: the name of the connection that has the parameters
            we need to connect to Segment.
            The connection should be type `json` and include a
            write_key security token in the `Extras` field.
        :type segment_conn_id: str
        :param segment_debug_mode: Determines whether Segment should run in debug mode.
            Defaults to False
        :type segment_debug_mode: boolean
        .. note::
            You must include a JSON structure in the `Extras` field.
            We need a user's security token to connect to Segment.
            So we define it in the `Extras` field as:
            `{"write_key":"YOUR_SECURITY_TOKEN"}`
        """
        self.segment_conn_id = segment_conn_id
        self.segment_debug_mode = segment_debug_mode
        self._args = args
        self._kwargs = kwargs

        # get the connection parameters
        self.connection = self.get_connection(self.segment_conn_id)
        self.extras = self.connection.extra_dejson
        self.write_key = self.extras.get('write_key')
        if self.write_key is None:
            raise AirflowException('No Segment write key provided')

    def get_conn(self):
        self.log.info('Setting write key for Segment analytics connection')
        analytics.debug = self.segment_debug_mode
        if self.segment_debug_mode:
            self.log.info('Setting Segment analytics connection to debug mode')
            analytics.on_error = self.on_error
        analytics.write_key = self.write_key
        return analytics

    def on_error(self, error, items):
        """
        Handles error callbacks when using Segment with segment_debug_mode set to True
        """
        self.log.error('Encountered Segment error: {segment_error} with '
                       'items: {with_items}'.format(segment_error=error,
                                                    with_items=items))
        raise AirflowException('Segment error: {}'.format(error))
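For orientation, a minimal usage sketch (not part of this commit) of driving the hook directly from Python. It assumes a `segment_default` connection whose `Extras` field carries the write key, as the docstring above describes; the user id, event name, and properties below are illustrative.

# Hypothetical usage sketch, not part of this commit.
# Assumes a 'segment_default' connection with {"write_key": "..."} in Extras.
from airflow.contrib.hooks.segment_hook import SegmentHook


def send_track_event():
    hook = SegmentHook(segment_conn_id='segment_default',
                       segment_debug_mode=True)
    # get_conn() returns the configured segmentio analytics module
    analytics = hook.get_conn()
    analytics.track(user_id='user-123',            # illustrative user id
                    event='Signed Up',             # illustrative event name
                    properties={'plan': 'basic'})  # illustrative properties
    # flush() blocks until queued events are delivered (analytics-python API)
    analytics.flush()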
@@ -0,0 +1,69 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from airflow.contrib.hooks.segment_hook import SegmentHook
from airflow.models import BaseOperator
from airflow.utils.decorators import apply_defaults


class SegmentTrackEventOperator(BaseOperator):
    """
    Send Track Event to Segment for a specified user_id and event

    :param user_id: The ID for this user in your database
    :type user_id: string
    :param event: The name of the event you're tracking
    :type event: string
    :param properties: A dictionary of properties for the event.
    :type properties: dict
    :param segment_conn_id: The connection ID to use when connecting to Segment.
    :type segment_conn_id: string
    :param segment_debug_mode: Determines whether Segment should run in debug mode.
        Defaults to False
    :type segment_debug_mode: boolean
    """
    template_fields = ('user_id', 'event', 'properties')
    ui_color = '#ffd700'

    @apply_defaults
    def __init__(self,
                 user_id,
                 event,
                 properties=None,
                 segment_conn_id='segment_default',
                 segment_debug_mode=False,
                 *args,
                 **kwargs):
        super(SegmentTrackEventOperator, self).__init__(*args, **kwargs)
        self.user_id = user_id
        self.event = event
        self.properties = properties or {}
        self.segment_debug_mode = segment_debug_mode
        self.segment_conn_id = segment_conn_id

    def execute(self, context):
        hook = SegmentHook(segment_conn_id=self.segment_conn_id,
                           segment_debug_mode=self.segment_debug_mode)

        self.log.info(
            'Sending track event ({0}) for user id: {1} with properties: {2}'.
            format(self.event, self.user_id, self.properties))

        hook.track(self.user_id, self.event, self.properties)
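As a usage illustration (again not part of the commit), the operator could be wired into a DAG as sketched below; the dag id, schedule, and property values are assumptions, and `{{ ds }}` is there only to show that `user_id`, `event`, and `properties` are templated fields. The module path follows the docs entry added in this commit.

# Hypothetical DAG sketch, not part of this commit.
from datetime import datetime

from airflow import DAG
from airflow.contrib.operators.segment_track_event_operator import \
    SegmentTrackEventOperator

dag = DAG(dag_id='segment_track_example',  # illustrative dag id
          start_date=datetime(2018, 5, 1),
          schedule_interval='@daily')

track_signup = SegmentTrackEventOperator(
    task_id='track_signup',               # illustrative task id
    user_id='user-123',                   # illustrative user id
    event='Signed Up',                    # illustrative event name
    properties={'run_date': '{{ ds }}'},  # rendered via template_fields
    segment_conn_id='segment_default',
    dag=dag)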
@@ -600,6 +600,7 @@ class Connection(Base, LoggingMixin):
        ('aws', 'Amazon Web Services',),
        ('emr', 'Elastic MapReduce',),
        ('snowflake', 'Snowflake',),
        ('segment', 'Segment',),
    ]

    def __init__(
@@ -7,9 +7,9 @@
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -268,6 +268,10 @@ def initdb(rbac=False):
        models.Connection(
            conn_id='qubole_default', conn_type='qubole',
            host='localhost'))
    merge_conn(
        models.Connection(
            conn_id='segment_default', conn_type='segment',
            extra='{"write_key": "my-segment-write-key"}'))

    # Known event types
    KET = models.KnownEventType
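Note that `initdb` seeds `segment_default` with a placeholder write key, so a real key still has to be supplied before the connection is usable. One way to do that, sketched here under the assumption of this era's ORM session API (not part of the commit), is:

# Hypothetical sketch, not part of this commit: replace the placeholder
# write key that initdb seeds into the 'segment_default' connection.
from airflow import models, settings

session = settings.Session()
conn = (session.query(models.Connection)
        .filter(models.Connection.conn_id == 'segment_default')
        .one())
conn.extra = '{"write_key": "YOUR_REAL_WRITE_KEY"}'  # your actual Segment token
session.commit()
session.close()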
@@ -172,6 +172,7 @@ Operators
.. autoclass:: airflow.contrib.operators.qubole_operator.QuboleOperator
.. autoclass:: airflow.contrib.operators.s3_list_operator.S3ListOperator
.. autoclass:: airflow.contrib.operators.s3_to_gcs_operator.S3ToGoogleCloudStorageOperator
.. autoclass:: airflow.contrib.operators.segment_track_event_operator.SegmentTrackEventOperator
.. autoclass:: airflow.contrib.operators.sftp_operator.SFTPOperator
.. autoclass:: airflow.contrib.operators.slack_webhook_operator.SlackWebhookOperator
.. autoclass:: airflow.contrib.operators.snowflake_operator.SnowflakeOperator

@@ -372,6 +373,7 @@ Community contributed hooks
.. autoclass:: airflow.contrib.hooks.redis_hook.RedisHook
.. autoclass:: airflow.contrib.hooks.redshift_hook.RedshiftHook
.. autoclass:: airflow.contrib.hooks.salesforce_hook.SalesforceHook
.. autoclass:: airflow.contrib.hooks.segment_hook.SegmentHook
.. autoclass:: airflow.contrib.hooks.sftp_hook.SFTPHook
.. autoclass:: airflow.contrib.hooks.slack_webhook_hook.SlackWebhookHook
.. autoclass:: airflow.contrib.hooks.snowflake_hook.SnowflakeHook
setup.py

@@ -177,6 +177,7 @@ redis = ['redis>=2.10.5']
s3 = ['boto3>=1.7.0']
salesforce = ['simple-salesforce>=0.72']
samba = ['pysmbclient>=0.1.3']
segment = ['analytics-python>=1.2.9']
slack = ['slackclient>=1.0.0']
snowflake = ['snowflake-connector-python>=1.5.2',
             'snowflake-sqlalchemy>=1.1.0']

@@ -211,7 +212,7 @@ devel_hadoop = devel_minreq + hive + hdfs + webhdfs + kerberos
devel_all = (sendgrid + devel + all_dbs + doc + samba + s3 + slack + crypto + oracle +
             docker + ssh + kubernetes + celery + azure + redis + gcp_api + datadog +
             zendesk + jdbc + ldap + kerberos + password + webhdfs + jenkins +
             druid + pinot + segment + snowflake + elasticsearch)

# Snakebite & Google Cloud Dataflow are not Python 3 compatible :'(
if PY3:

@@ -316,6 +317,7 @@ def do_setup():
            'salesforce': salesforce,
            'samba': samba,
            'sendgrid': sendgrid,
            'segment': segment,
            'slack': slack,
            'snowflake': snowflake,
            'ssh': ssh,
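With these setup.py entries in place, the Segment client can be pulled in through the new `segment` extras group, e.g. `pip install apache-airflow[segment]` (assuming the `apache-airflow` distribution name of this era), which installs `analytics-python>=1.2.9`.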
@@ -0,0 +1,64 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
import mock
import unittest

from airflow import configuration, AirflowException

from airflow.contrib.hooks.segment_hook import SegmentHook

TEST_CONN_ID = 'test_segment'
WRITE_KEY = 'foo'


class TestSegmentHook(unittest.TestCase):

    def setUp(self):
        super(TestSegmentHook, self).setUp()
        configuration.load_test_config()

        self.conn = conn = mock.MagicMock()
        conn.write_key = WRITE_KEY
        self.expected_write_key = WRITE_KEY
        self.conn.extra_dejson = {'write_key': self.expected_write_key}

        class UnitTestSegmentHook(SegmentHook):

            def get_conn(self):
                return conn

            def get_connection(self, connection_id):
                return conn

        self.test_hook = UnitTestSegmentHook(segment_conn_id=TEST_CONN_ID)

    def test_get_conn(self):
        expected_connection = self.test_hook.get_conn()
        self.assertEqual(expected_connection, self.conn)
        self.assertIsNotNone(expected_connection.write_key)
        self.assertEqual(expected_connection.write_key, self.expected_write_key)

    def test_on_error(self):
        with self.assertRaises(AirflowException):
            self.test_hook.on_error('error', ['items'])


if __name__ == '__main__':
    unittest.main()
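The commit only adds tests for the hook; a companion test for the operator might look like the sketch below (not part of this commit), which mocks out `SegmentHook` at the spot where the operator imports it. The module path mirrors the docs entry above and is an assumption here.

# Hypothetical operator test sketch, not part of this commit.
import mock
import unittest

from airflow.contrib.operators.segment_track_event_operator import \
    SegmentTrackEventOperator


class TestSegmentTrackEventOperator(unittest.TestCase):

    @mock.patch('airflow.contrib.operators.segment_track_event_operator'
                '.SegmentHook')
    def test_execute(self, mock_hook):
        operator = SegmentTrackEventOperator(
            task_id='segment_track',   # illustrative task id
            user_id='user-123',
            event='Signed Up',
            properties={'plan': 'basic'})

        operator.execute(context={})
        # the operator should delegate to the hook's track call
        mock_hook.return_value.track.assert_called_once_with(
            'user-123', 'Signed Up', {'plan': 'basic'})


if __name__ == '__main__':
    unittest.main()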
@@ -1038,6 +1038,7 @@ class CliTests(unittest.TestCase):
        self.assertIn(['mysql_default', 'mysql'], conns)
        self.assertIn(['postgres_default', 'postgres'], conns)
        self.assertIn(['wasb_default', 'wasb'], conns)
        self.assertIn(['segment_default', 'segment'], conns)

        # Attempt to list connections with invalid cli args
        with mock.patch('sys.stdout',