diff --git a/.gitignore b/.gitignore index ce543d37c3..f2f7ed7518 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ *.pyc +airflow/www/static/coverage/ .ipynb* docs/_* airflow.db @@ -16,3 +17,4 @@ secrets.py *.egg-info *.bkp .DS_Store +airflow_login.py diff --git a/TODO.md b/TODO.md index 79753d2564..5f1e83e08a 100644 --- a/TODO.md +++ b/TODO.md @@ -4,31 +4,22 @@ TODO * Backfill wizard #### unittests -* Increase coverage, now 80ish% +* Increase coverage, now 85ish% #### Command line * `airflow task_state dag_id task_id YYYY-MM-DD` #### More Operators! -* Sandbox the BashOperator -* S3Sensor * BaseDataTransferOperator * File2MySqlOperator -* DagTaskSensor for cross dag dependencies * PIG -#### Frontend -* - #### Backend +* Make authentication universal * Callbacks * Master auto dag refresh at time intervals * Prevent timezone chagne on import * Add decorator to timeout imports on master process [lib](https://github.com/pnpnpn/timeout-decorator) -* Make authentication universal - -#### Misc -* Write an hypervisor, looks for dead jobs without a heartbeat and kills #### Wishlist * Support for cron like synthax (0 * * * ) using croniter library diff --git a/airflow/__init__.py b/airflow/__init__.py index c6d6d05f61..94e0a4102c 100644 --- a/airflow/__init__.py +++ b/airflow/__init__.py @@ -1,3 +1,18 @@ -from models import DAG - __version__ = "0.4.3" + +''' +Authentication is implemented using flask_login and different environments can +implement their own login mechanisms by providing an `airflow_login` module +in their PYTHONPATH. 
airflow_login should be based off the +`airflow.www.login` +''' +try: + # Environment specific login + import airflow_login + login = airflow_login +except ImportError: + # Default login, no real authentication + from airflow import default_login + login = default_login + +from models import DAG diff --git a/airflow/default_login.py b/airflow/default_login.py new file mode 100644 index 0000000000..3a70a3de70 --- /dev/null +++ b/airflow/default_login.py @@ -0,0 +1,72 @@ +''' +Override this file to handle your authentication / login. + +Copy and alter this file and put it in your PYTHONPATH as airflow_login.py, +the new module will override this one. +''' + +import flask_login +from flask_login import login_required, current_user, logout_user + +from flask import url_for, redirect + +from airflow import settings +from airflow import models + +DEFAULT_USERNAME = 'airflow' + +login_manager = flask_login.LoginManager() +login_manager.login_view = 'airflow.login' # Calls login() below +login_manager.login_message = None + + +class User(models.BaseUser): + + def is_active(self): + '''Required by flask_login''' + return True + + def is_authenticated(self): + '''Required by flask_login''' + return True + + def is_anonymous(self): + '''Required by flask_login''' + return False + + def data_profiling(self): + '''Provides access to data profiling tools''' + return True + + def is_superuser(self): + '''Access all the things''' + return True + +models.User = User # hack! 
+ + +@login_manager.user_loader +def load_user(userid): + session = settings.Session() + user = session.query(User).filter(User.id == userid).first() + session.expunge_all() + session.commit() + session.close() + return user + + +def login(self, request): + session = settings.Session() + user = session.query(User).filter( + User.username == DEFAULT_USERNAME).first() + if not user: + user = User( + username=DEFAULT_USERNAME, + has_access=True, + is_superuser=True) + session.merge(user) + session.expunge_all() + session.commit() + session.close() + flask_login.login_user(user) + return redirect(request.args.get("next") or url_for("index")) diff --git a/airflow/models.py b/airflow/models.py index 060d3cbfd5..f802fcd535 100644 --- a/airflow/models.py +++ b/airflow/models.py @@ -183,34 +183,19 @@ class DagBag(object): return dag_ids -class User(Base): - """ - Eventually should be used for security purposes - """ +class BaseUser(Base): __tablename__ = "user" + id = Column(Integer, primary_key=True) username = Column(String(ID_LEN), unique=True) email = Column(String(500)) - def __init__(self, username=None, email=None): - self.username = username - self.email = email - def __repr__(self): return self.username def get_id(self): return unicode(self.id) - def is_active(self): - return True - - def is_authenticated(self): - return True - - def is_anonymous(self): - return False - class Connection(Base): """ @@ -1507,6 +1492,9 @@ class Chart(Base): iteration_no = Column(Integer, default=0) last_modified = Column(DateTime, default=datetime.now()) + def __repr__(self): + return self.label + class KnownEventType(Base): __tablename__ = "known_event_type" @@ -1534,3 +1522,6 @@ class KnownEvent(Base): cascade=False, cascade_backrefs=False, backref='known_events') description = Column(Text) + + def __repr__(self): + return self.label diff --git a/airflow/operators/email_operator.py b/airflow/operators/email_operator.py index d913da3c0f..6758b84cf3 100644 --- 
a/airflow/operators/email_operator.py +++ b/airflow/operators/email_operator.py @@ -6,6 +6,7 @@ from airflow.utils import apply_defaults class EmailOperator(BaseOperator): template_fields = ('subject', 'html_content') + template_ext = ('.html',) ui_color = '#e6faf9' __mapper_args__ = { diff --git a/airflow/www/app.py b/airflow/www/app.py index c326031059..f8d57b93d8 100644 --- a/airflow/www/app.py +++ b/airflow/www/app.py @@ -4,10 +4,8 @@ import dateutil.parser import json import logging import os -import re import socket import sys -import urllib2 from flask import Flask, url_for, Markup, Blueprint, redirect, flash, Response from flask.ext.admin import Admin, BaseView, expose, AdminIndexView @@ -28,18 +26,58 @@ import markdown import chartkick import airflow +login_required = airflow.login.login_required +current_user = airflow.login.current_user +logout_user = airflow.login.logout_user + from airflow.settings import Session from airflow import jobs from airflow import models +from airflow import login from airflow.models import State from airflow import settings from airflow.configuration import conf from airflow import utils from airflow.www import utils as wwwutils -from airflow.www.login import login_manager -import flask_login -from flask_login import login_required +from functools import wraps + +AUTHENTICATE = conf.getboolean('master', 'AUTHENTICATE') +if AUTHENTICATE is False: + login_required = lambda x: x + + +def superuser_required(f): + ''' + Decorator for views requiring superuser access + ''' + @wraps(f) + def decorated_function(*args, **kwargs): + if ( + not AUTHENTICATE or + (not current_user.is_anonymous() and current_user.is_superuser()) + ): + return f(*args, **kwargs) + else: + flash("This page requires superuser privileges", "error") + return redirect(url_for('admin.index')) + return decorated_function + +def data_profiling_required(f): + ''' + Decorator for views requiring data profiling access + ''' + @wraps(f) + def 
decorated_function(*args, **kwargs): + if ( + not AUTHENTICATE or + (not current_user.is_anonymous() and current_user.data_profiling()) + ): + return f(*args, **kwargs) + else: + flash("This page requires data profiling privileges", "error") + return redirect(url_for('admin.index')) + return decorated_function QUERY_LIMIT = 100000 CHART_LIMIT = 200000 @@ -50,9 +88,6 @@ special_attrs = { 'bash_command': BashLexer, } -AUTHENTICATE = conf.getboolean('master', 'AUTHENTICATE') -if AUTHENTICATE is False: - login_required = lambda x: x dagbag = models.DagBag(os.path.expanduser(conf.get('core', 'DAGS_FOLDER'))) utils.pessimistic_connection_handling() @@ -60,7 +95,7 @@ utils.pessimistic_connection_handling() app = Flask(__name__) app.config['SQLALCHEMY_POOL_RECYCLE'] = 3600 -login_manager.init_app(app) +login.login_manager.init_app(app) app.secret_key = 'airflowified' cache = Cache( @@ -113,6 +148,7 @@ class HomeView(AdminIndexView): Basic home view, just showing the README.md file """ @expose("/") + @login_required def index(self): dags = dagbag.dags.values() dags = [dag for dag in dags if not dag.parent_dag] @@ -125,12 +161,6 @@ admin = Admin( index_view=HomeView(name='DAGs'), template_mode='bootstrap3') -admin.add_link( - base.MenuLink( - category='Data Profiling', - name='Ad Hoc Query', - url='/admin/airflow/query')) - class Airflow(BaseView): @@ -138,73 +168,12 @@ return False @expose('/') + @login_required def index(self): return self.render('airflow/dags.html') - @expose('/query') - @login_required - @wwwutils.gzipped - def query(self): - session = settings.Session() - dbs = session.query(models.Connection).order_by( - models.Connection.conn_id) - db_choices = [(db.conn_id, db.conn_id) for db in dbs if db.get_hook()] - conn_id_str = request.args.get('conn_id') - csv = request.args.get('csv') == "true" - sql = request.args.get('sql') - - class QueryForm(Form): - conn_id = SelectField("Layout", choices=db_choices) - sql = TextAreaField("SQL", 
widget=wwwutils.AceEditorWidget()) - data = { - 'conn_id': conn_id_str, - 'sql': sql, - } - results = None - has_data = False - error = False - if conn_id_str: - db = [db for db in dbs if db.conn_id == conn_id_str][0] - hook = db.get_hook() - try: - df = hook.get_pandas_df(wwwutils.limit_sql(sql, QUERY_LIMIT)) - # df = hook.get_pandas_df(sql) - has_data = len(df) > 0 - df = df.fillna('') - results = df.to_html( - classes="table table-bordered table-striped no-wrap", - index=False, - na_rep='', - ) if has_data else '' - except Exception as e: - flash(str(e), 'error') - error = True - - if has_data and len(df) == QUERY_LIMIT: - flash( - "Query output truncated at " + str(QUERY_LIMIT) + - " rows", 'info') - - if not has_data and error: - flash('No data', 'error') - - if csv: - return Response( - response=df.to_csv(index=False), - status=200, - mimetype="application/text") - - form = QueryForm(request.form, data=data) - session.commit() - session.close() - return self.render( - 'airflow/query.html', form=form, - title="Ad Hoc Query", - results=results or '', - has_data=has_data) - @expose('/chart_data') - @login_required + @data_profiling_required @wwwutils.gzipped @cache.cached(timeout=3600, key_prefix=wwwutils.make_cache_key) def chart_data(self): @@ -472,7 +441,7 @@ class Airflow(BaseView): mimetype="application/json") @expose('/chart') - @login_required + @data_profiling_required def chart(self): session = settings.Session() chart_id = request.args.get('chart_id') @@ -497,6 +466,7 @@ class Airflow(BaseView): label=chart.label) @expose('/dag_stats') + @login_required def dag_stats(self): states = [State.SUCCESS, State.RUNNING, State.FAILED] task_ids = [] @@ -537,6 +507,7 @@ class Airflow(BaseView): status=200, mimetype="application/json") @expose('/code') + @login_required def code(self): dag_id = request.args.get('dag_id') dag = dagbag.dags[dag_id] @@ -548,6 +519,7 @@ class Airflow(BaseView): 'airflow/dag_code.html', html_code=html_code, dag=dag, title=title) 
@expose('/circles') + @login_required def circles(self): return self.render( 'airflow/circles.html') @@ -570,33 +542,6 @@ class Airflow(BaseView): 'airflow/code.html', code_html=code_html, title=title, subtitle=cfg_loc) - @expose('/conf') - @login_required - def conf(self): - from airflow import configuration - raw = request.args.get('raw') == "true" - title = "Airflow Configuration" - subtitle = configuration.AIRFLOW_CONFIG - f = open(configuration.AIRFLOW_CONFIG, 'r') - config = f.read() - - f.close() - if raw: - return Response( - response=config, - status=200, - mimetype="application/text") - else: - code_html = Markup(highlight( - config, - IniLexer(), # Lexer call - HtmlFormatter(noclasses=True)) - ) - return self.render( - 'airflow/code.html', - pre_subtitle=settings.HEADER + " v" + airflow.__version__, - code_html=code_html, title=title, subtitle=subtitle) - @expose('/noaccess') def noaccess(self): return self.render('airflow/noaccess.html') @@ -630,53 +575,25 @@ class Airflow(BaseView): @expose('/headers') def headers(self): d = {k: v for k, v in request.headers} + d['is_superuser'] = current_user.is_superuser() + d['data_profiling'] = current_user.data_profiling() + d['is_anonymous'] = current_user.is_anonymous() + d['is_authenticated'] = current_user.is_authenticated() return Response( response=json.dumps(d, indent=4), status=200, mimetype="application/json") @expose('/login') def login(self): - session = settings.Session() - roles = [ - 'airpal_topsecret.engineering.airbnb.com', - 'hadoop_user.engineering.airbnb.com', - 'analytics.engineering.airbnb.com', - 'nerds.engineering.airbnb.com', - ] - if 'X-Internalauth-Username' not in request.headers: - return redirect(url_for('airflow.noaccess')) - username = request.headers.get('X-Internalauth-Username') - groups = request.headers.get( - 'X-Internalauth-Groups').lower().split(',') - has_access = any([g in roles for g in groups]) - - d = {k: v for k, v in request.headers} - cookie = 
urllib2.unquote(d.get('Cookie', '')) - mailsrch = re.compile( - r'[\w\-][\w\-\.]+@[\w\-][\w\-\.]+[a-zA-Z]{1,4}') - res = mailsrch.findall(cookie) - email = res[0] if res else '' - - if has_access: - user = session.query(models.User).filter( - models.User.username == username).first() - if not user: - user = models.User(username=username) - user.email = email - session.merge(user) - session.commit() - flask_login.login_user(user) - session.commit() - session.close() - return redirect(request.args.get("next") or url_for("index")) - return redirect('/') + return login.login(self, request) @expose('/logout') def logout(self): - flask_login.logout_user() + logout_user() return redirect('/admin') @expose('/rendered') + @login_required def rendered(self): dag_id = request.args.get('dag_id') task_id = request.args.get('task_id') @@ -714,6 +631,7 @@ class Airflow(BaseView): title=title,) @expose('/log') + @login_required def log(self): BASE_LOG_FOLDER = os.path.expanduser( conf.get('core', 'BASE_LOG_FOLDER')) @@ -767,6 +685,7 @@ class Airflow(BaseView): execution_date=execution_date, form=form) @expose('/task') + @login_required def task(self): dag_id = request.args.get('dag_id') task_id = request.args.get('task_id') @@ -810,6 +729,7 @@ class Airflow(BaseView): dag=dag, title=title) @expose('/action') + @login_required def action(self): action = request.args.get('action') dag_id = request.args.get('dag_id') @@ -926,6 +846,7 @@ class Airflow(BaseView): return response @expose('/tree') + @login_required @wwwutils.gzipped def tree(self): dag_id = request.args.get('dag_id') @@ -1015,6 +936,7 @@ class Airflow(BaseView): dag=dag, data=data) @expose('/graph') + @login_required def graph(self): session = settings.Session() dag_id = request.args.get('dag_id') @@ -1089,6 +1011,7 @@ class Airflow(BaseView): edges=json.dumps(edges, indent=2),) @expose('/duration') + @login_required def duration(self): session = settings.Session() dag_id = request.args.get('dag_id') @@ -1120,6 +1043,7 
@@ class Airflow(BaseView): ) @expose('/landing_times') + @login_required def landing_times(self): session = settings.Session() dag_id = request.args.get('dag_id') @@ -1152,6 +1076,7 @@ class Airflow(BaseView): ) @expose('/gantt') + @login_required def gantt(self): session = settings.Session() @@ -1220,18 +1145,96 @@ class Airflow(BaseView): admin.add_view(Airflow(name='DAGs')) -# ------------------------------------------------ -# Leveraging the admin for CRUD and browse on models -# ------------------------------------------------ - class LoginMixin(object): def is_accessible(self): - return AUTHENTICATE is False or \ - flask_login.current_user.is_authenticated() + return ( + not AUTHENTICATE or + (not current_user.is_anonymous() and current_user.is_authenticated()) + ) -class ModelViewOnly(ModelView): +class DataProfilingMixin(object): + def is_accessible(self): + return ( + not AUTHENTICATE or + (not current_user.is_anonymous() and current_user.data_profiling()) + ) + + +class SuperUserMixin(object): + def is_accessible(self): + return ( + not AUTHENTICATE or + (not current_user.is_anonymous() and current_user.is_superuser()) + ) + + +class QueryView(DataProfilingMixin, BaseView): + @expose('/') + @wwwutils.gzipped + def query(self): + session = settings.Session() + dbs = session.query(models.Connection).order_by( + models.Connection.conn_id) + db_choices = [(db.conn_id, db.conn_id) for db in dbs if db.get_hook()] + conn_id_str = request.args.get('conn_id') + csv = request.args.get('csv') == "true" + sql = request.args.get('sql') + + class QueryForm(Form): + conn_id = SelectField("Layout", choices=db_choices) + sql = TextAreaField("SQL", widget=wwwutils.AceEditorWidget()) + data = { + 'conn_id': conn_id_str, + 'sql': sql, + } + results = None + has_data = False + error = False + if conn_id_str: + db = [db for db in dbs if db.conn_id == conn_id_str][0] + hook = db.get_hook() + try: + df = hook.get_pandas_df(wwwutils.limit_sql(sql, QUERY_LIMIT)) + # df = 
hook.get_pandas_df(sql) + has_data = len(df) > 0 + df = df.fillna('') + results = df.to_html( + classes="table table-bordered table-striped no-wrap", + index=False, + na_rep='', + ) if has_data else '' + except Exception as e: + flash(str(e), 'error') + error = True + + if has_data and len(df) == QUERY_LIMIT: + flash( + "Query output truncated at " + str(QUERY_LIMIT) + + " rows", 'info') + + if not has_data and error: + flash('No data', 'error') + + if csv: + return Response( + response=df.to_csv(index=False), + status=200, + mimetype="application/text") + + form = QueryForm(request.form, data=data) + session.commit() + session.close() + return self.render( + 'airflow/query.html', form=form, + title="Ad Hoc Query", + results=results or '', + has_data=has_data) +admin.add_view(QueryView(name='Ad Hoc Query', category="Data Profiling")) + + +class ModelViewOnly(LoginMixin, ModelView): """ Modifying the base ModelView class for non edit, browse only operations """ @@ -1310,14 +1313,7 @@ mv = TaskInstanceModelView( admin.add_view(mv) -admin.add_link( - base.MenuLink( - category='Admin', - name='Configuration', - url='/admin/airflow/conf')) - - -class ConnectionModelView(LoginMixin, ModelView): +class ConnectionModelView(SuperUserMixin, ModelView): column_list = ('conn_id', 'conn_type', 'host', 'port') form_choices = { 'conn_type': [ @@ -1340,13 +1336,13 @@ mv = ConnectionModelView( admin.add_view(mv) -class UserModelView(LoginMixin, ModelView): +class UserModelView(SuperUserMixin, ModelView): column_default_sort = 'username' mv = UserModelView(models.User, Session, name="Users", category="Admin") admin.add_view(mv) -class ReloadTaskView(BaseView): +class ReloadTaskView(SuperUserMixin, BaseView): @expose('/') def index(self): logging.info("Reloading the dags") @@ -1356,7 +1352,36 @@ class ReloadTaskView(BaseView): admin.add_view(ReloadTaskView(name='Reload DAGs', category="Admin")) -class DagModelView(ModelView): +class ConfigurationView(SuperUserMixin, BaseView): + 
@expose('/') + def conf(self): + from airflow import configuration + raw = request.args.get('raw') == "true" + title = "Airflow Configuration" + subtitle = configuration.AIRFLOW_CONFIG + f = open(configuration.AIRFLOW_CONFIG, 'r') + config = f.read() + + f.close() + if raw: + return Response( + response=config, + status=200, + mimetype="application/text") + else: + code_html = Markup(highlight( + config, + IniLexer(), # Lexer call + HtmlFormatter(noclasses=True)) + ) + return self.render( + 'airflow/code.html', + pre_subtitle=settings.HEADER + " v" + airflow.__version__, + code_html=code_html, title=title, subtitle=subtitle) +admin.add_view(ConfigurationView(name='Configuration', category="Admin")) + + +class DagModelView(SuperUserMixin, ModelView): column_list = ('dag_id', 'is_paused') column_editable_list = ('is_paused',) mv = DagModelView( @@ -1375,7 +1400,7 @@ def label_link(v, c, m, p): return Markup("{m.label}".format(**locals())) -class ChartModelView(LoginMixin, ModelView): +class ChartModelView(DataProfilingMixin, ModelView): form_columns = ( 'label', 'owner', @@ -1461,8 +1486,8 @@ class ChartModelView(LoginMixin, ModelView): model.iteration_no = 0 else: model.iteration_no += 1 - if AUTHENTICATE and not model.user_id and flask_login.current_user: - model.user_id = flask_login.current_user.id + if AUTHENTICATE and not model.user_id and current_user: + model.user_id = current_user.id model.last_modified = datetime.now() mv = ChartModelView( @@ -1482,7 +1507,7 @@ admin.add_link( url='https://github.com/mistercrunch/Airflow')) -class KnowEventView(LoginMixin, ModelView): +class KnowEventView(DataProfilingMixin, ModelView): form_columns = ( 'label', 'event_type', @@ -1497,7 +1522,7 @@ mv = KnowEventView( admin.add_view(mv) -class KnowEventTypeView(LoginMixin, ModelView): +class KnowEventTypeView(DataProfilingMixin, ModelView): pass ''' mv = KnowEventTypeView( diff --git a/airflow/www/login.py b/airflow/www/login.py deleted file mode 100644 index 
e1cb546d4f..0000000000 --- a/airflow/www/login.py +++ /dev/null @@ -1,21 +0,0 @@ -import flask_login - -from airflow.models import User -from airflow import settings - -login_manager = flask_login.LoginManager() - -@login_manager.user_loader -def load_user(userid): - session = settings.Session() - user = session.query(User).filter(User.id == userid).first() - #if not user: - # raise Exception(userid) - session.expunge_all() - session.commit() - session.close() - return user - - -login_manager.login_view = 'airflow.login' -login_manager.login_message = None diff --git a/docs/installation.rst b/docs/installation.rst index fa2d6ec34d..3264f37ce8 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -5,10 +5,10 @@ working towards a production grade environment is a bit more work. Extra Packages '''''''''''''' -The `airflow` PyPI basic package only installs what's needed to get started. +The ``airflow`` PyPI basic package only installs what's needed to get started. Subpackages can be installed depending on what will be useful in your environment. For instance, if you don't need connectivity with Postgres, -you won't have to go through the trouble of install the `postgres-devel` yum +you won't have to go through the trouble of installing the ``postgres-devel`` yum package, or whatever equivalent on the distribution you are using. 
Behind the scene, we do conditional imports on operators that require @@ -19,17 +19,17 @@ Here's the list of the subpackages and that they enable: +-------------+------------------------------------+---------------------------------------+ | subpackage | install command | enables | +=============+====================================+=======================================+ -| mysql | pip install airflow[mysql] | MySQL operators and hook, support as | +| mysql | ``pip install airflow[mysql]`` | MySQL operators and hook, support as | | | | an Airflow backend | +-------------+------------------------------------+---------------------------------------+ -| postgres | pip install airflow[postgres] | Postgres operators and hook, support | +| postgres | ``pip install airflow[postgres]`` | Postgres operators and hook, support | | | | as an Airflow backend | +-------------+------------------------------------+---------------------------------------+ -| samba | pip install airflow[samba] | Hive2SambaOperator | +| samba | ``pip install airflow[samba]`` | ``Hive2SambaOperator`` | +-------------+------------------------------------+---------------------------------------+ -| s3 | pip install airflow[s3] | S3KeySensor, S3PrefixSensor | +| s3 | ``pip install airflow[s3]`` | ``S3KeySensor``, ``S3PrefixSensor`` | +-------------+------------------------------------+---------------------------------------+ -| all | pip install airflow[all] | All Airflow features known to man | +| all | ``pip install airflow[all]`` | All Airflow features known to man | +-------------+------------------------------------+---------------------------------------+ @@ -82,3 +82,19 @@ its direction. Note that you can also run "Celery Flower" a web UI build on top of Celery to monitor your workers. + + +Web Authentication +'''''''''''''''''' + +By default, all gates are opened. An easy way to restrict access +to the web application is to do it at the network level, or by using +ssh tunnels. 
+ +However, it is possible to switch on +authentication and define exactly how your users should login +into your Airflow environment. Airflow uses ``flask_login`` and +exposes a set of hooks in the ``airflow.default_login`` module. You can +alter the content of this module by overriding it as a ``airflow_login`` +module. To do this, you would typically copy/paste ``airflow.default_login`` +in a ``airflow_login.py`` and put it directly in your ``PYTHONPATH``. diff --git a/tests/core.py b/tests/core.py index 19226a370c..e5462eacc7 100644 --- a/tests/core.py +++ b/tests/core.py @@ -225,48 +225,57 @@ class WebUiTests(unittest.TestCase): def tearDown(self): pass +if 'MySqlOperator' in dir(operators): + # Only testing if the operator is installed + class MySqlTest(unittest.TestCase): -class MySqlTest(unittest.TestCase): + def setUp(self): + configuration.test_mode() + utils.initdb() + args = {'owner': 'airflow', 'start_date': datetime(2015, 1, 1)} + dag = DAG('hive_test', default_args=args) + self.dag = dag - def setUp(self): - configuration.test_mode() - utils.initdb() - args = {'owner': 'airflow', 'start_date': datetime(2015, 1, 1)} - dag = DAG('hive_test', default_args=args) - self.dag = dag + def mysql_operator_test(self): + sql = """ + CREATE TABLE IF NOT EXISTS test_airflow ( + dummy VARCHAR(50) + ); + """ + t = operators.MySqlOperator( + task_id='basic_mysql', sql=sql, dag=self.dag) + t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, force=True) - def mysql_operator_test(self): - sql = """ - CREATE TABLE IF NOT EXISTS test_airflow ( - dummy VARCHAR(50) - ); - """ - t = operators.MySqlOperator( - task_id='basic_mysql', sql=sql, dag=self.dag) - t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, force=True) +if 'PostgresOperator' in dir(operators): + # Only testing if the operator is installed + class PostgresTest(unittest.TestCase): -class PostgresTest(unittest.TestCase): + def setUp(self): + configuration.test_mode() + utils.initdb() + args = {'owner': 
'airflow', 'start_date': datetime(2015, 1, 1)} + dag = DAG('hive_test', default_args=args) + self.dag = dag - def setUp(self): - configuration.test_mode() - utils.initdb() - args = {'owner': 'airflow', 'start_date': datetime(2015, 1, 1)} - dag = DAG('hive_test', default_args=args) - self.dag = dag + def postgres_operator_test(self): + sql = """ + CREATE TABLE IF NOT EXISTS test_airflow ( + dummy VARCHAR(50) + ); + """ + t = operators.PostgresOperator( + task_id='basic_postgres', sql=sql, dag=self.dag) + t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, force=True) - def postgres_operator_test(self): - sql = """ - CREATE TABLE IF NOT EXISTS test_airflow ( - dummy VARCHAR(50) - ); - """ - t = operators.PostgresOperator( - task_id='basic_postgres', sql=sql, dag=self.dag) - t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, force=True) - - autocommitTask = operators.PostgresOperator( - task_id='basic_postgres_with_autocommit', sql=sql, dag=self.dag, autocommit=True) - autocommitTask.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, force=True) + autocommitTask = operators.PostgresOperator( + task_id='basic_postgres_with_autocommit', + sql=sql, + dag=self.dag, + autocommit=True) + autocommitTask.run( + start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE, + force=True) if __name__ == '__main__':