Making login module generic / overridable

This commit is contained in:
Maxime Beauchemin 2015-03-14 16:01:11 -07:00 коммит произвёл Maxime
Родитель e3521c17fc
Коммит 4518dcb326
10 изменённых файлов: 359 добавлений и 258 удалений

2
.gitignore поставляемый
Просмотреть файл

@ -1,4 +1,5 @@
*.pyc
airflow/www/static/coverage/
.ipynb*
docs/_*
airflow.db
@ -16,3 +17,4 @@ secrets.py
*.egg-info
*.bkp
.DS_Store
airflow_login.py

13
TODO.md
Просмотреть файл

@ -4,31 +4,22 @@ TODO
* Backfill wizard
#### unittests
* Increase coverage, now 80ish%
* Increase coverage, now 85ish%
#### Command line
* `airflow task_state dag_id task_id YYYY-MM-DD`
#### More Operators!
* Sandbox the BashOperator
* S3Sensor
* BaseDataTransferOperator
* File2MySqlOperator
* DagTaskSensor for cross dag dependencies
* PIG
#### Frontend
*
#### Backend
* Make authentication universal
* Callbacks
* Master auto dag refresh at time intervals
* Prevent timezone change on import
* Add decorator to timeout imports on master process [lib](https://github.com/pnpnpn/timeout-decorator)
* Make authentication universal
#### Misc
* Write a hypervisor that looks for dead jobs without a heartbeat and kills them
#### Wishlist
* Support for cron-like syntax (0 * * * *) using the croniter library

Просмотреть файл

@ -1,3 +1,18 @@
from models import DAG
__version__ = "0.4.3"
'''
Authentication is implemented using flask_login and different environments can
implement their own login mechanisms by providing an `airflow_login` module
in their PYTHONPATH. `airflow_login` should be based on the
`airflow.default_login` module.
'''
try:
# Environment specific login
import airflow_login
login = airflow_login
except ImportError:
# Default login, no real authentication
from airflow import default_login
login = default_login
from models import DAG

72
airflow/default_login.py Normal file
Просмотреть файл

@ -0,0 +1,72 @@
'''
Override this file to handle your authentication / login.
Copy and alter this file and put it in your PYTHONPATH as airflow_login.py;
the new module will override this one.
'''
import flask_login
from flask_login import login_required, current_user, logout_user
from flask import url_for, redirect
from airflow import settings
from airflow import models
# Username that login() below looks up (and creates when it is missing).
DEFAULT_USERNAME = 'airflow'
# flask_login manager exposed by this module.
login_manager = flask_login.LoginManager()
login_manager.login_view = 'airflow.login' # Calls login() below
# No login message is configured.
login_manager.login_message = None
class User(models.BaseUser):
    '''
    Default user: every flag below is hard-coded, so this class performs
    no real authentication or authorization.
    '''

    def is_superuser(self):
        '''Access all the things'''
        return True

    def data_profiling(self):
        '''Provides access to data profiling tools'''
        return True

    def is_anonymous(self):
        '''Required by flask_login; this user is never anonymous'''
        return False

    def is_authenticated(self):
        '''Required by flask_login; this user always counts as authenticated'''
        return True

    def is_active(self):
        '''Required by flask_login; this user is always active'''
        return True
models.User = User # hack!
@login_manager.user_loader
def load_user(userid):
    '''
    User loader registered with flask_login.

    Looks up the `User` row whose id equals `userid` and returns it, or
    returns None when no row matches.
    '''
    session = settings.Session()
    try:
        user = session.query(User).filter(User.id == userid).first()
        # Detach the loaded object from the session before it is closed.
        session.expunge_all()
        session.commit()
    finally:
        # Release the session even when the query or commit raised.
        session.close()
    return user
def login(self, request):
    '''
    Default login hook: logs the visitor in as DEFAULT_USERNAME without
    asking for any credentials, creating that user if it does not exist.

    `self` is the view instance the caller passes through (unused here);
    `request` is the current flask request. Returns a redirect to the
    `next` query argument when present, otherwise to the "index" endpoint.
    '''
    session = settings.Session()
    try:
        user = session.query(User).filter(
            User.username == DEFAULT_USERNAME).first()
        if not user:
            # NOTE(review): `has_access` is not an attribute visible on the
            # user model from here -- confirm the constructor accepts it.
            user = User(
                username=DEFAULT_USERNAME,
                has_access=True,
                is_superuser=True)
        # merge() returns the session-bound instance; the object passed in
        # is left untouched, so keep the returned one.
        user = session.merge(user)
        # Flush before detaching: this emits the INSERT for a new user and
        # assigns its primary key, which flask_login needs via get_id().
        # Without it, expunge_all() would drop the pending row unsaved.
        session.flush()
        session.expunge_all()
        session.commit()
    finally:
        # Release the session even when something above raised.
        session.close()
    flask_login.login_user(user)
    return redirect(request.args.get("next") or url_for("index"))

Просмотреть файл

@ -183,34 +183,19 @@ class DagBag(object):
return dag_ids
class User(Base):
"""
Eventually should be used for security purposes
"""
class BaseUser(Base):
__tablename__ = "user"
id = Column(Integer, primary_key=True)
username = Column(String(ID_LEN), unique=True)
email = Column(String(500))
def __init__(self, username=None, email=None):
self.username = username
self.email = email
def __repr__(self):
return self.username
def get_id(self):
return unicode(self.id)
def is_active(self):
return True
def is_authenticated(self):
return True
def is_anonymous(self):
return False
class Connection(Base):
"""
@ -1507,6 +1492,9 @@ class Chart(Base):
iteration_no = Column(Integer, default=0)
last_modified = Column(DateTime, default=datetime.now())
def __repr__(self):
return self.label
class KnownEventType(Base):
__tablename__ = "known_event_type"
@ -1534,3 +1522,6 @@ class KnownEvent(Base):
cascade=False,
cascade_backrefs=False, backref='known_events')
description = Column(Text)
def __repr__(self):
return self.label

Просмотреть файл

@ -6,6 +6,7 @@ from airflow.utils import apply_defaults
class EmailOperator(BaseOperator):
template_fields = ('subject', 'html_content')
template_ext = ('.html',)
ui_color = '#e6faf9'
__mapper_args__ = {

Просмотреть файл

@ -4,10 +4,8 @@ import dateutil.parser
import json
import logging
import os
import re
import socket
import sys
import urllib2
from flask import Flask, url_for, Markup, Blueprint, redirect, flash, Response
from flask.ext.admin import Admin, BaseView, expose, AdminIndexView
@ -28,18 +26,58 @@ import markdown
import chartkick
import airflow
login_required = airflow.login.login_required
current_user = airflow.login.current_user
logout_user = airflow.login.logout_user
from airflow.settings import Session
from airflow import jobs
from airflow import models
from airflow import login
from airflow.models import State
from airflow import settings
from airflow.configuration import conf
from airflow import utils
from airflow.www import utils as wwwutils
from airflow.www.login import login_manager
import flask_login
from flask_login import login_required
from functools import wraps
AUTHENTICATE = conf.getboolean('master', 'AUTHENTICATE')
if AUTHENTICATE is False:
login_required = lambda x: x
def superuser_required(f):
    '''
    View decorator: only let the request through when authentication is
    switched off or the current user is a non-anonymous superuser.
    Everyone else gets an error flash and a redirect to the admin index.
    '''
    @wraps(f)
    def wrapper(*args, **kwargs):
        allowed = not AUTHENTICATE or (
            not current_user.is_anonymous() and current_user.is_superuser())
        if not allowed:
            flash("This page requires superuser priviledges", "error")
            return redirect(url_for('admin.index'))
        return f(*args, **kwargs)
    return wrapper
def data_profiling_required(f):
    '''
    View decorator: only let the request through when authentication is
    switched off or the current user is non-anonymous and has data
    profiling access. Everyone else gets an error flash and a redirect to
    the admin index.
    '''
    @wraps(f)
    def wrapper(*args, **kwargs):
        allowed = not AUTHENTICATE or (
            not current_user.is_anonymous() and current_user.data_profiling())
        if not allowed:
            flash("This page requires data profiling priviledges", "error")
            return redirect(url_for('admin.index'))
        return f(*args, **kwargs)
    return wrapper
QUERY_LIMIT = 100000
CHART_LIMIT = 200000
@ -50,9 +88,6 @@ special_attrs = {
'bash_command': BashLexer,
}
AUTHENTICATE = conf.getboolean('master', 'AUTHENTICATE')
if AUTHENTICATE is False:
login_required = lambda x: x
dagbag = models.DagBag(os.path.expanduser(conf.get('core', 'DAGS_FOLDER')))
utils.pessimistic_connection_handling()
@ -60,7 +95,7 @@ utils.pessimistic_connection_handling()
app = Flask(__name__)
app.config['SQLALCHEMY_POOL_RECYCLE'] = 3600
login_manager.init_app(app)
login.login_manager.init_app(app)
app.secret_key = 'airflowified'
cache = Cache(
@ -113,6 +148,7 @@ class HomeView(AdminIndexView):
Basic home view, just showing the README.md file
"""
@expose("/")
@login_required
def index(self):
dags = dagbag.dags.values()
dags = [dag for dag in dags if not dag.parent_dag]
@ -125,12 +161,6 @@ admin = Admin(
index_view=HomeView(name='DAGs'),
template_mode='bootstrap3')
admin.add_link(
base.MenuLink(
category='Data Profiling',
name='Ad Hoc Query',
url='/admin/airflow/query'))
class Airflow(BaseView):
@ -138,73 +168,12 @@ class Airflow(BaseView):
return False
@expose('/')
@login_required
def index(self):
return self.render('airflow/dags.html')
@expose('/query')
@login_required
@wwwutils.gzipped
def query(self):
session = settings.Session()
dbs = session.query(models.Connection).order_by(
models.Connection.conn_id)
db_choices = [(db.conn_id, db.conn_id) for db in dbs if db.get_hook()]
conn_id_str = request.args.get('conn_id')
csv = request.args.get('csv') == "true"
sql = request.args.get('sql')
class QueryForm(Form):
conn_id = SelectField("Layout", choices=db_choices)
sql = TextAreaField("SQL", widget=wwwutils.AceEditorWidget())
data = {
'conn_id': conn_id_str,
'sql': sql,
}
results = None
has_data = False
error = False
if conn_id_str:
db = [db for db in dbs if db.conn_id == conn_id_str][0]
hook = db.get_hook()
try:
df = hook.get_pandas_df(wwwutils.limit_sql(sql, QUERY_LIMIT))
# df = hook.get_pandas_df(sql)
has_data = len(df) > 0
df = df.fillna('')
results = df.to_html(
classes="table table-bordered table-striped no-wrap",
index=False,
na_rep='',
) if has_data else ''
except Exception as e:
flash(str(e), 'error')
error = True
if has_data and len(df) == QUERY_LIMIT:
flash(
"Query output truncated at " + str(QUERY_LIMIT) +
" rows", 'info')
if not has_data and error:
flash('No data', 'error')
if csv:
return Response(
response=df.to_csv(index=False),
status=200,
mimetype="application/text")
form = QueryForm(request.form, data=data)
session.commit()
session.close()
return self.render(
'airflow/query.html', form=form,
title="Ad Hoc Query",
results=results or '',
has_data=has_data)
@expose('/chart_data')
@login_required
@data_profiling_required
@wwwutils.gzipped
@cache.cached(timeout=3600, key_prefix=wwwutils.make_cache_key)
def chart_data(self):
@ -472,7 +441,7 @@ class Airflow(BaseView):
mimetype="application/json")
@expose('/chart')
@login_required
@data_profiling_required
def chart(self):
session = settings.Session()
chart_id = request.args.get('chart_id')
@ -497,6 +466,7 @@ class Airflow(BaseView):
label=chart.label)
@expose('/dag_stats')
@login_required
def dag_stats(self):
states = [State.SUCCESS, State.RUNNING, State.FAILED]
task_ids = []
@ -537,6 +507,7 @@ class Airflow(BaseView):
status=200, mimetype="application/json")
@expose('/code')
@login_required
def code(self):
dag_id = request.args.get('dag_id')
dag = dagbag.dags[dag_id]
@ -548,6 +519,7 @@ class Airflow(BaseView):
'airflow/dag_code.html', html_code=html_code, dag=dag, title=title)
@expose('/circles')
@login_required
def circles(self):
return self.render(
'airflow/circles.html')
@ -570,33 +542,6 @@ class Airflow(BaseView):
'airflow/code.html',
code_html=code_html, title=title, subtitle=cfg_loc)
@expose('/conf')
@login_required
def conf(self):
from airflow import configuration
raw = request.args.get('raw') == "true"
title = "Airflow Configuration"
subtitle = configuration.AIRFLOW_CONFIG
f = open(configuration.AIRFLOW_CONFIG, 'r')
config = f.read()
f.close()
if raw:
return Response(
response=config,
status=200,
mimetype="application/text")
else:
code_html = Markup(highlight(
config,
IniLexer(), # Lexer call
HtmlFormatter(noclasses=True))
)
return self.render(
'airflow/code.html',
pre_subtitle=settings.HEADER + " v" + airflow.__version__,
code_html=code_html, title=title, subtitle=subtitle)
@expose('/noaccess')
def noaccess(self):
return self.render('airflow/noaccess.html')
@ -630,53 +575,25 @@ class Airflow(BaseView):
@expose('/headers')
def headers(self):
d = {k: v for k, v in request.headers}
d['is_superuser'] = current_user.is_superuser()
d['data_profiling'] = current_user.data_profiling()
d['is_anonymous'] = current_user.is_anonymous()
d['is_authenticated'] = current_user.is_authenticated()
return Response(
response=json.dumps(d, indent=4),
status=200, mimetype="application/json")
@expose('/login')
def login(self):
session = settings.Session()
roles = [
'airpal_topsecret.engineering.airbnb.com',
'hadoop_user.engineering.airbnb.com',
'analytics.engineering.airbnb.com',
'nerds.engineering.airbnb.com',
]
if 'X-Internalauth-Username' not in request.headers:
return redirect(url_for('airflow.noaccess'))
username = request.headers.get('X-Internalauth-Username')
groups = request.headers.get(
'X-Internalauth-Groups').lower().split(',')
has_access = any([g in roles for g in groups])
d = {k: v for k, v in request.headers}
cookie = urllib2.unquote(d.get('Cookie', ''))
mailsrch = re.compile(
r'[\w\-][\w\-\.]+@[\w\-][\w\-\.]+[a-zA-Z]{1,4}')
res = mailsrch.findall(cookie)
email = res[0] if res else ''
if has_access:
user = session.query(models.User).filter(
models.User.username == username).first()
if not user:
user = models.User(username=username)
user.email = email
session.merge(user)
session.commit()
flask_login.login_user(user)
session.commit()
session.close()
return redirect(request.args.get("next") or url_for("index"))
return redirect('/')
return login.login(self, request)
@expose('/logout')
def logout(self):
flask_login.logout_user()
logout_user()
return redirect('/admin')
@expose('/rendered')
@login_required
def rendered(self):
dag_id = request.args.get('dag_id')
task_id = request.args.get('task_id')
@ -714,6 +631,7 @@ class Airflow(BaseView):
title=title,)
@expose('/log')
@login_required
def log(self):
BASE_LOG_FOLDER = os.path.expanduser(
conf.get('core', 'BASE_LOG_FOLDER'))
@ -767,6 +685,7 @@ class Airflow(BaseView):
execution_date=execution_date, form=form)
@expose('/task')
@login_required
def task(self):
dag_id = request.args.get('dag_id')
task_id = request.args.get('task_id')
@ -810,6 +729,7 @@ class Airflow(BaseView):
dag=dag, title=title)
@expose('/action')
@login_required
def action(self):
action = request.args.get('action')
dag_id = request.args.get('dag_id')
@ -926,6 +846,7 @@ class Airflow(BaseView):
return response
@expose('/tree')
@login_required
@wwwutils.gzipped
def tree(self):
dag_id = request.args.get('dag_id')
@ -1015,6 +936,7 @@ class Airflow(BaseView):
dag=dag, data=data)
@expose('/graph')
@login_required
def graph(self):
session = settings.Session()
dag_id = request.args.get('dag_id')
@ -1089,6 +1011,7 @@ class Airflow(BaseView):
edges=json.dumps(edges, indent=2),)
@expose('/duration')
@login_required
def duration(self):
session = settings.Session()
dag_id = request.args.get('dag_id')
@ -1120,6 +1043,7 @@ class Airflow(BaseView):
)
@expose('/landing_times')
@login_required
def landing_times(self):
session = settings.Session()
dag_id = request.args.get('dag_id')
@ -1152,6 +1076,7 @@ class Airflow(BaseView):
)
@expose('/gantt')
@login_required
def gantt(self):
session = settings.Session()
@ -1220,18 +1145,96 @@ class Airflow(BaseView):
admin.add_view(Airflow(name='DAGs'))
# ------------------------------------------------
# Leveraging the admin for CRUD and browse on models
# ------------------------------------------------
class LoginMixin(object):
def is_accessible(self):
return AUTHENTICATE is False or \
flask_login.current_user.is_authenticated()
return (
not AUTHENTICATE or
(not current_user.is_anonymous() and current_user.is_authenticated())
)
class ModelViewOnly(ModelView):
class DataProfilingMixin(object):
    '''
    Admin view mixin: the view is accessible when authentication is off,
    or when the current user is not anonymous and has data profiling access.
    '''
    def is_accessible(self):
        if not AUTHENTICATE:
            return True
        return (
            not current_user.is_anonymous() and
            current_user.data_profiling())
class SuperUserMixin(object):
    '''
    Admin view mixin: the view is accessible when authentication is off,
    or when the current user is not anonymous and is a superuser.
    '''
    def is_accessible(self):
        if not AUTHENTICATE:
            return True
        return (
            not current_user.is_anonymous() and
            current_user.is_superuser())
class QueryView(DataProfilingMixin, BaseView):
    '''
    "Ad Hoc Query" page: runs the submitted SQL against the selected
    connection and shows the result as an HTML table or as CSV.
    '''

    @expose('/')
    @wwwutils.gzipped
    def query(self):
        '''
        Reads `conn_id`, `sql` and `csv` from the query string.

        When `conn_id` is given, the SQL is wrapped by `wwwutils.limit_sql`
        with QUERY_LIMIT and executed through the connection's hook. With
        `csv=true` and a result available, the raw CSV is returned;
        otherwise the query page is rendered, with any errors flashed.
        '''
        conn_id_str = request.args.get('conn_id')
        csv = request.args.get('csv') == "true"
        sql = request.args.get('sql')

        results = None
        has_data = False
        error = False
        # Stays None when no query ran or the hook raised before returning,
        # so the CSV branch below can tell whether there is anything to send.
        df = None

        session = settings.Session()
        try:
            # Materialize once: the list is walked twice below.
            dbs = session.query(models.Connection).order_by(
                models.Connection.conn_id).all()
            db_choices = [
                (db.conn_id, db.conn_id) for db in dbs if db.get_hook()]

            if conn_id_str:
                matches = [db for db in dbs if db.conn_id == conn_id_str]
                if not matches:
                    # conn_id comes from the URL; report instead of crashing.
                    flash('Unknown connection id', 'error')
                    error = True
                else:
                    hook = matches[0].get_hook()
                    try:
                        df = hook.get_pandas_df(
                            wwwutils.limit_sql(sql, QUERY_LIMIT))
                        has_data = len(df) > 0
                        df = df.fillna('')
                        results = df.to_html(
                            classes=(
                                "table table-bordered table-striped no-wrap"),
                            index=False,
                            na_rep='',
                        ) if has_data else ''
                    except Exception as e:
                        # Best effort: show whatever the hook/driver raised.
                        flash(str(e), 'error')
                        error = True

            if has_data and len(df) == QUERY_LIMIT:
                flash(
                    "Query output truncated at " + str(QUERY_LIMIT) +
                    " rows", 'info')

            if not has_data and error:
                flash('No data', 'error')

            session.commit()
        finally:
            # Release the session on every path, including CSV and errors.
            session.close()

        if csv and df is not None:
            return Response(
                response=df.to_csv(index=False),
                status=200,
                mimetype="application/text")

        class QueryForm(Form):
            conn_id = SelectField("Layout", choices=db_choices)
            sql = TextAreaField("SQL", widget=wwwutils.AceEditorWidget())

        data = {
            'conn_id': conn_id_str,
            'sql': sql,
        }
        form = QueryForm(request.form, data=data)
        return self.render(
            'airflow/query.html', form=form,
            title="Ad Hoc Query",
            results=results or '',
            has_data=has_data)
admin.add_view(QueryView(name='Ad Hoc Query', category="Data Profiling"))
class ModelViewOnly(LoginMixin, ModelView):
"""
Modifying the base ModelView class for non edit, browse only operations
"""
@ -1310,14 +1313,7 @@ mv = TaskInstanceModelView(
admin.add_view(mv)
admin.add_link(
base.MenuLink(
category='Admin',
name='Configuration',
url='/admin/airflow/conf'))
class ConnectionModelView(LoginMixin, ModelView):
class ConnectionModelView(SuperUserMixin, ModelView):
column_list = ('conn_id', 'conn_type', 'host', 'port')
form_choices = {
'conn_type': [
@ -1340,13 +1336,13 @@ mv = ConnectionModelView(
admin.add_view(mv)
class UserModelView(LoginMixin, ModelView):
class UserModelView(SuperUserMixin, ModelView):
column_default_sort = 'username'
mv = UserModelView(models.User, Session, name="Users", category="Admin")
admin.add_view(mv)
class ReloadTaskView(BaseView):
class ReloadTaskView(SuperUserMixin, BaseView):
@expose('/')
def index(self):
logging.info("Reloading the dags")
@ -1356,7 +1352,36 @@ class ReloadTaskView(BaseView):
admin.add_view(ReloadTaskView(name='Reload DAGs', category="Admin"))
class DagModelView(ModelView):
class ConfigurationView(SuperUserMixin, BaseView):
    '''
    Admin page showing the content of the Airflow configuration file.
    '''

    @expose('/')
    def conf(self):
        '''
        Returns the file at `configuration.AIRFLOW_CONFIG`, either as plain
        text when the `raw=true` query argument is given, or rendered as
        highlighted HTML in the 'airflow/code.html' template.
        '''
        from airflow import configuration
        raw = request.args.get('raw') == "true"
        title = "Airflow Configuration"
        subtitle = configuration.AIRFLOW_CONFIG
        # `with` closes the file even when read() raises.
        with open(configuration.AIRFLOW_CONFIG, 'r') as f:
            config = f.read()
        if raw:
            return Response(
                response=config,
                status=200,
                mimetype="application/text")
        code_html = Markup(highlight(
            config,
            IniLexer(),  # Lexer call
            HtmlFormatter(noclasses=True))
        )
        return self.render(
            'airflow/code.html',
            pre_subtitle=settings.HEADER + " v" + airflow.__version__,
            code_html=code_html, title=title, subtitle=subtitle)
admin.add_view(ConfigurationView(name='Configuration', category="Admin"))
class DagModelView(SuperUserMixin, ModelView):
column_list = ('dag_id', 'is_paused')
column_editable_list = ('is_paused',)
mv = DagModelView(
@ -1375,7 +1400,7 @@ def label_link(v, c, m, p):
return Markup("<a href='{url}'>{m.label}</a>".format(**locals()))
class ChartModelView(LoginMixin, ModelView):
class ChartModelView(DataProfilingMixin, ModelView):
form_columns = (
'label',
'owner',
@ -1461,8 +1486,8 @@ class ChartModelView(LoginMixin, ModelView):
model.iteration_no = 0
else:
model.iteration_no += 1
if AUTHENTICATE and not model.user_id and flask_login.current_user:
model.user_id = flask_login.current_user.id
if AUTHENTICATE and not model.user_id and current_user:
model.user_id = current_user.id
model.last_modified = datetime.now()
mv = ChartModelView(
@ -1482,7 +1507,7 @@ admin.add_link(
url='https://github.com/mistercrunch/Airflow'))
class KnowEventView(LoginMixin, ModelView):
class KnowEventView(DataProfilingMixin, ModelView):
form_columns = (
'label',
'event_type',
@ -1497,7 +1522,7 @@ mv = KnowEventView(
admin.add_view(mv)
class KnowEventTypeView(LoginMixin, ModelView):
class KnowEventTypeView(DataProfilingMixin, ModelView):
pass
'''
mv = KnowEventTypeView(

Просмотреть файл

@ -1,21 +0,0 @@
import flask_login
from airflow.models import User
from airflow import settings
login_manager = flask_login.LoginManager()
@login_manager.user_loader
def load_user(userid):
session = settings.Session()
user = session.query(User).filter(User.id == userid).first()
#if not user:
# raise Exception(userid)
session.expunge_all()
session.commit()
session.close()
return user
login_manager.login_view = 'airflow.login'
login_manager.login_message = None

Просмотреть файл

@ -5,10 +5,10 @@ working towards a production grade environment is a bit more work.
Extra Packages
''''''''''''''
The `airflow` PyPI basic package only installs what's needed to get started.
The ``airflow`` PyPI basic package only installs what's needed to get started.
Subpackages can be installed depending on what will be useful in your
environment. For instance, if you don't need connectivity with Postgres,
you won't have to go through the trouble of install the `postgres-devel` yum
you won't have to go through the trouble of installing the ``postgres-devel`` yum
package, or whatever equivalent on the distribution you are using.
Behind the scene, we do conditional imports on operators that require
@ -19,17 +19,17 @@ Here's the list of the subpackages and that they enable:
+-------------+------------------------------------+---------------------------------------+
| subpackage | install command | enables |
+=============+====================================+=======================================+
| mysql | pip install airflow[mysql] | MySQL operators and hook, support as |
| mysql | ``pip install airflow[mysql]`` | MySQL operators and hook, support as |
| | | an Airflow backend |
+-------------+------------------------------------+---------------------------------------+
| postgres | pip install airflow[postgres] | Postgres operators and hook, support |
| postgres | ``pip install airflow[postgres]`` | Postgres operators and hook, support |
| | | as an Airflow backend |
+-------------+------------------------------------+---------------------------------------+
| samba | pip install airflow[samba] | Hive2SambaOperator |
| samba | ``pip install airflow[samba]`` | ``Hive2SambaOperator`` |
+-------------+------------------------------------+---------------------------------------+
| s3 | pip install airflow[s3] | S3KeySensor, S3PrefixSensor |
| s3 | ``pip install airflow[s3]`` | ``S3KeySensor``, ``S3PrefixSensor`` |
+-------------+------------------------------------+---------------------------------------+
| all | pip install airflow[all] | All Airflow features known to man |
| all | ``pip install airflow[all]`` | All Airflow features known to man |
+-------------+------------------------------------+---------------------------------------+
@ -82,3 +82,19 @@ its direction.
Note that you can also run "Celery Flower", a web UI built on top of Celery,
to monitor your workers.
Web Authentication
''''''''''''''''''
By default, all gates are opened. An easy way to restrict access
to the web application is to do it at the network level, or by using
ssh tunnels.
However, it is possible to switch on
authentication and define exactly how your users should log in
to your Airflow environment. Airflow uses ``flask_login`` and
exposes a set of hooks in the ``airflow.default_login`` module. You can
alter the content of this module by overriding it with an ``airflow_login``
module. To do this, you would typically copy/paste ``airflow.default_login``
into an ``airflow_login.py`` file and put it directly in your ``PYTHONPATH``.

Просмотреть файл

@ -225,48 +225,57 @@ class WebUiTests(unittest.TestCase):
def tearDown(self):
pass
if 'MySqlOperator' in dir(operators):
# Only testing if the operator is installed
class MySqlTest(unittest.TestCase):
class MySqlTest(unittest.TestCase):
def setUp(self):
configuration.test_mode()
utils.initdb()
args = {'owner': 'airflow', 'start_date': datetime(2015, 1, 1)}
dag = DAG('hive_test', default_args=args)
self.dag = dag
def setUp(self):
configuration.test_mode()
utils.initdb()
args = {'owner': 'airflow', 'start_date': datetime(2015, 1, 1)}
dag = DAG('hive_test', default_args=args)
self.dag = dag
def mysql_operator_test(self):
sql = """
CREATE TABLE IF NOT EXISTS test_airflow (
dummy VARCHAR(50)
);
"""
t = operators.MySqlOperator(
task_id='basic_mysql', sql=sql, dag=self.dag)
t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, force=True)
def mysql_operator_test(self):
sql = """
CREATE TABLE IF NOT EXISTS test_airflow (
dummy VARCHAR(50)
);
"""
t = operators.MySqlOperator(
task_id='basic_mysql', sql=sql, dag=self.dag)
t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, force=True)
if 'PostgresOperator' in dir(operators):
# Only testing if the operator is installed
class PostgresTest(unittest.TestCase):
class PostgresTest(unittest.TestCase):
def setUp(self):
configuration.test_mode()
utils.initdb()
args = {'owner': 'airflow', 'start_date': datetime(2015, 1, 1)}
dag = DAG('hive_test', default_args=args)
self.dag = dag
def setUp(self):
configuration.test_mode()
utils.initdb()
args = {'owner': 'airflow', 'start_date': datetime(2015, 1, 1)}
dag = DAG('hive_test', default_args=args)
self.dag = dag
def postgres_operator_test(self):
sql = """
CREATE TABLE IF NOT EXISTS test_airflow (
dummy VARCHAR(50)
);
"""
t = operators.PostgresOperator(
task_id='basic_postgres', sql=sql, dag=self.dag)
t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, force=True)
def postgres_operator_test(self):
sql = """
CREATE TABLE IF NOT EXISTS test_airflow (
dummy VARCHAR(50)
);
"""
t = operators.PostgresOperator(
task_id='basic_postgres', sql=sql, dag=self.dag)
t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, force=True)
autocommitTask = operators.PostgresOperator(
task_id='basic_postgres_with_autocommit', sql=sql, dag=self.dag, autocommit=True)
autocommitTask.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, force=True)
autocommitTask = operators.PostgresOperator(
task_id='basic_postgres_with_autocommit',
sql=sql,
dag=self.dag,
autocommit=True)
autocommitTask.run(
start_date=DEFAULT_DATE,
end_date=DEFAULT_DATE,
force=True)
if __name__ == '__main__':