Rename DatabaseConnection to Connection

This commit is contained in:
Maxime 2015-01-11 00:43:01 +00:00
Родитель 46645898f2
Коммит 76af64332d
11 изменённых файлов: 79 добавлений и 79 удалений

Просмотреть файл

@@ -28,8 +28,8 @@ WORKER_LOG_SERVER_PORT = 8793
[hooks]
HIVE_HOME_PY: '/usr/lib/hive/lib/py'
PRESTO_DEFAULT_DBID: presto_default
HIVE_DEFAULT_DBID: hive_default
PRESTO_DEFAULT_CONN_ID: presto_default
HIVE_DEFAULT_CONN_ID: hive_default
[misc]
RUN_AS_MASTER: True

Просмотреть файл

@@ -250,28 +250,28 @@ def initdb(args):
# Creating the local_mysql DB connection
session = settings.Session()
session.query(models.DatabaseConnection).delete()
session.query(models.Connection).delete()
session.add(
models.DatabaseConnection(
db_id='local_mysql', db_type='mysql',
models.Connection(
conn_id='local_mysql', db_type='mysql',
host='localhost', login='airflow', password='airflow',
schema='airflow'))
session.commit()
session.add(
models.DatabaseConnection(
db_id='mysql_default', db_type='mysql',
models.Connection(
conn_id='mysql_default', db_type='mysql',
host='localhost', login='airflow', password='airflow',
schema='airflow'))
session.commit()
session.add(
models.DatabaseConnection(
db_id='presto_default', db_type='presto',
models.Connection(
conn_id='presto_default', db_type='presto',
host='localhost',
schema='hive', port=10001))
session.commit()
session.add(
models.DatabaseConnection(
db_id='hive_default', db_type='hive',
models.Connection(
conn_id='hive_default', db_type='hive',
host='localhost',
schema='default', port=10000))
session.commit()

Просмотреть файл

@@ -4,7 +4,7 @@ import subprocess
import sys
from tempfile import NamedTemporaryFile
from airflow.models import DatabaseConnection
from airflow.models import Connection
from airflow.configuration import conf
from airflow import settings
@@ -21,13 +21,13 @@ from airflow.hooks.base_hook import BaseHook
class HiveHook(BaseHook):
def __init__(self,
hive_dbid=conf.get('hooks', 'HIVE_DEFAULT_DBID')):
hive_conn_id=conf.get('hooks', 'HIVE_DEFAULT_CONN_ID')):
session = settings.Session()
db = session.query(
DatabaseConnection).filter(
DatabaseConnection.db_id == hive_dbid)
Connection).filter(
Connection.conn_id == hive_conn_id)
if db.count() == 0:
raise Exception("The dbid you provided isn't defined")
raise Exception("The conn_id you provided isn't defined")
else:
db = db.all()[0]
self.host = db.host

Просмотреть файл

@@ -1,13 +1,14 @@
import MySQLdb
from airflow import settings
from airflow.models import DatabaseConnection
from airflow.models import Connection
class MySqlHook(object):
def __init__(
self, host=None, login=None, psw=None, db=None, mysql_dbid=None):
if not mysql_dbid:
self, host=None, login=None,
psw=None, db=None, mysql_conn_id=None):
if not mysql_conn_id:
self.host = host
self.login = login
self.psw = psw
@@ -15,8 +16,8 @@ class MySqlHook(object):
else:
session = settings.Session()
db = session.query(
DatabaseConnection).filter(
DatabaseConnection.db_id == mysql_dbid)
Connection).filter(
Connection.conn_id == mysql_conn_id)
if db.count() == 0:
raise Exception("The mysql_dbid you provided isn't defined")
else:

Просмотреть файл

@@ -1,8 +1,6 @@
import subprocess
from airflow import settings
from airflow.configuration import conf
from airflow.models import DatabaseConnection
from airflow.models import Connection
from airflow.hooks.base_hook import BaseHook
from airflow.hooks.presto.presto_client import PrestoClient
@@ -13,13 +11,14 @@ class PrestoHook(BaseHook):
"""
Interact with Presto!
"""
def __init__(self, presto_dbid=conf.get('hooks', 'PRESTO_DEFAULT_DBID')):
def __init__(
self, presto_conn_id=conf.get('hooks', 'PRESTO_DEFAULT_CONN_ID')):
session = settings.Session()
db = session.query(
DatabaseConnection).filter(
DatabaseConnection.db_id == presto_dbid)
Connection).filter(
Connection.conn_id == presto_conn_id)
if db.count() == 0:
raise Exception("The presto_dbid you provided isn't defined")
raise Exception("The presto_conn_id you provided isn't defined")
else:
db = db.all()[0]
self.host = db.host

Просмотреть файл

@@ -2,9 +2,9 @@ from airflow.configuration import conf
def max_partition(
table, schema="default",
hive_dbid=conf.get('hooks', 'HIVE_DEFAULT_DBID')):
hive_conn_id=conf.get('hooks', 'HIVE_DEFAULT_CONN_ID')):
from airflow.hooks.hive_hook import HiveHook
if '.' in table:
schema, table = table.split('.')
hh = HiveHook(hive_dbid=hive_dbid)
hh = HiveHook(hive_conn_id=hive_conn_id)
return hh.max_partition(schema=schema, table_name=table)

Просмотреть файл

@@ -130,18 +130,18 @@ class User(Base):
return False
class DatabaseConnection(Base):
class Connection(Base):
"""
Placeholder to store information about different database instances
connection information. The idea here is that scripts use references to
database instances (db_id) instead of hard coding hostname, logins and
database instances (conn_id) instead of hard coding hostname, logins and
passwords when using operators or hooks.
"""
__tablename__ = "db_connection"
__tablename__ = "connection"
id = Column(Integer(), primary_key=True)
db_id = Column(String(ID_LEN), unique=True)
db_type = Column(String(500))
conn_id = Column(String(ID_LEN), unique=True)
conn_type = Column(String(500))
host = Column(String(500))
schema = Column(String(500))
login = Column(String(500))
@@ -150,11 +150,11 @@ class DatabaseConnection(Base):
extra = Column(String(5000))
def __init__(
self, db_id=None, db_type=None,
self, conn_id=None, conn_type=None,
host=None, login=None, password=None,
schema=None, port=None):
self.db_id = db_id
self.db_type = db_type
self.conn_id = conn_id
self.conn_type = conn_type
self.host = host
self.login = login
self.password = password
@@ -163,15 +163,15 @@ class DatabaseConnection(Base):
def get_hook(self):
from airflow import hooks
if self.db_type == 'mysql':
return hooks.MySqlHook(mysql_dbid=self.db_id)
elif self.db_type == 'hive':
return hooks.HiveHook(hive_dbid=self.db_id)
elif self.db_type == 'presto':
return hooks.PrestoHook(presto_dbid=self.db_id)
if self.conn_type == 'mysql':
return hooks.MySqlHook(mysql_conn_id=self.conn_id)
elif self.conn_type == 'hive':
return hooks.HiveHook(hive_conn_id=self.conn_id)
elif self.conn_type == 'presto':
return hooks.PrestoHook(presto_conn_id=self.conn_id)
def __repr__(self):
return self.db_id
return self.conn_id
class DagPickle(Base):
@@ -1150,7 +1150,7 @@ class Chart(Base):
id = Column(Integer, primary_key=True)
label = Column(String(200))
db_id = Column(String(ID_LEN), ForeignKey('db_connection.db_id'))
conn_id = Column(String(ID_LEN), ForeignKey('connection.conn_id'))
user_id = Column(Integer(), ForeignKey('user.id'),)
chart_type = Column(String(100), default="line")
sql_layout = Column(String(50), default="series")
@@ -1162,5 +1162,5 @@ class Chart(Base):
default_params = Column(String(5000), default="{}")
owner = relationship("User", cascade=False, cascade_backrefs=False)
x_is_date = Column(Boolean, default=True)
db = relationship("DatabaseConnection")
db = relationship("Connection")
iteration_no = Column(Integer, default=0)

Просмотреть файл

@@ -12,8 +12,8 @@ class HiveOperator(BaseOperator):
:param hql: the hql to be executed
:type hql: string
:param hive_dbid: reference to the Hive database
:type hive_dbid: string
:param hive_conn_id: reference to the Hive database
:type hive_conn_id: string
"""
__mapper_args__ = {
@@ -25,12 +25,12 @@ class HiveOperator(BaseOperator):
@apply_defaults
def __init__(
self, hql,
hive_dbid=conf.get('hooks', 'HIVE_DEFAULT_DBID'),
hive_conn_id=conf.get('hooks', 'HIVE_DEFAULT_CONN_ID'),
*args, **kwargs):
super(HiveOperator, self).__init__(*args, **kwargs)
self.hive_dbid = hive_dbid
self.hook = HiveHook(hive_dbid=hive_dbid)
self.hive_conn_id = hive_conn_id
self.hook = HiveHook(hive_conn_id=hive_conn_id)
self.hql = hql
def execute(self, execution_date):

Просмотреть файл

@@ -17,15 +17,15 @@ class MySqlOperator(BaseOperator):
template_ext = ('.sql',)
@apply_defaults
def __init__(self, sql, mysql_dbid, *args, **kwargs):
def __init__(self, sql, mysql_conn_id, *args, **kwargs):
"""
Parameters:
mysql_dbid: reference to a specific mysql database
mysql_conn_id: reference to a specific mysql database
sql: the sql code you to be executed
"""
super(MySqlOperator, self).__init__(*args, **kwargs)
self.hook = MySqlHook(mysql_dbid=mysql_dbid)
self.hook = MySqlHook(mysql_conn_id=mysql_conn_id)
self.sql = sql
def execute(self, execution_date):

Просмотреть файл

@@ -6,7 +6,7 @@ from airflow import settings
from airflow.configuration import conf
from airflow.hooks import HiveHook
from airflow.models import BaseOperator
from airflow.models import DatabaseConnection as DB
from airflow.models import Connection as DB
from airflow.models import State
from airflow.models import TaskInstance
from airflow.utils import apply_defaults
@@ -56,17 +56,17 @@ class SqlSensor(BaseSensorOperator):
}
@apply_defaults
def __init__(self, db_id, sql, *args, **kwargs):
def __init__(self, conn_id, sql, *args, **kwargs):
super(SqlSensor, self).__init__(*args, **kwargs)
self.sql = sql
self.db_id = db_id
self.conn_id = conn_id
session = settings.Session()
db = session.query(DB).filter(DB.db_id==db_id).all()
db = session.query(DB).filter(DB.conn_id==conn_id).all()
if not db:
raise Exception("db_id doesn't exist in the repository")
raise Exception("conn_id doesn't exist in the repository")
self.hook = db[0].get_hook()
session.commit()
session.close()
@@ -132,14 +132,14 @@ class HivePartitionSensor(BaseSensorOperator):
def __init__(
self,
table, partition="ds='{{ ds }}'",
hive_dbid=conf.get('hooks', 'HIVE_DEFAULT_DBID'),
hive_conn_id=conf.get('hooks', 'HIVE_DEFAULT_CONN_ID'),
schema='default',
*args, **kwargs):
super(HivePartitionSensor, self).__init__(*args, **kwargs)
if '.' in table:
schema, table = table.split('.')
self.hive_dbid = hive_dbid
self.hook = HiveHook(hive_dbid=hive_dbid)
self.hive_conn_id = hive_conn_id
self.hook = HiveHook(hive_conn_id=hive_conn_id)
self.table = table
self.partition = partition
self.schema = schema

Просмотреть файл

@@ -134,24 +134,24 @@ class Airflow(BaseView):
@wwwutils.gzipped
def query(self):
session = settings.Session()
dbs = session.query(models.DatabaseConnection).order_by(
models.DatabaseConnection.db_id)
db_choices = [(db.db_id, db.db_id) for db in dbs]
db_id_str = request.args.get('db_id')
dbs = session.query(models.Connection).order_by(
models.Connection.conn_id)
db_choices = [(db.conn_id, db.conn_id) for db in dbs]
conn_id_str = request.args.get('conn_id')
sql = request.args.get('sql')
class QueryForm(Form):
db_id = SelectField("Layout", choices=db_choices)
conn_id = SelectField("Layout", choices=db_choices)
sql = TextAreaField("SQL", widget=wwwutils.AceEditorWidget())
data = {
'db_id': db_id_str,
'conn_id': conn_id_str,
'sql': sql,
}
results = None
has_data = False
error = False
if db_id_str:
db = [db for db in dbs if db.db_id == db_id_str][0]
if conn_id_str:
db = [db for db in dbs if db.conn_id == conn_id_str][0]
hook = db.get_hook()
try:
# df = hook.get_pandas_df(wwwutils.limit_sql(sql, QUERY_LIMIT))
@@ -193,7 +193,7 @@ class Airflow(BaseView):
chart_id = request.args.get('chart_id')
chart = session.query(models.Chart).filter_by(id=chart_id).all()[0]
db = session.query(
models.DatabaseConnection).filter_by(db_id=chart.db_id).all()[0]
models.Connection).filter_by(conn_id=chart.conn_id).all()[0]
session.expunge_all()
payload = {}
@@ -1022,8 +1022,8 @@ mv = UserModelView(models.User, Session, name="Users", category="Admin")
admin.add_view(mv)
class DatabaseConnectionModelView(LoginMixin, ModelView):
column_list = ('db_id', 'db_type', 'host', 'port')
class ConnectionModelView(LoginMixin, ModelView):
column_list = ('conn_id', 'db_type', 'host', 'port')
form_choices = {
'db_type': [
('hive', 'Hive',),
@@ -1034,9 +1034,9 @@ class DatabaseConnectionModelView(LoginMixin, ModelView):
('ftp', 'FTP',),
]
}
mv = DatabaseConnectionModelView(
models.DatabaseConnection, Session,
name="Database Connections", category="Admin")
mv = ConnectionModelView(
models.Connection, Session,
name="Connections", category="Admin")
admin.add_view(mv)
@@ -1085,12 +1085,12 @@ class ChartModelView(LoginMixin, ModelView):
'sql',
'default_params',)
column_list = (
'label', 'db_id', 'chart_type', 'owner',
'label', 'conn_id', 'chart_type', 'owner',
'show_datatable', 'show_sql',)
column_formatters = dict(label=label_link)
create_template = 'airflow/chart/create.html'
edit_template = 'airflow/chart/edit.html'
column_filters = ('owner.username', 'db_id',)
column_filters = ('owner.username', 'conn_id',)
column_searchable_list = ('owner.username', 'label', 'sql')
column_descriptions = {
'label': "Can include {{ templated_fields }} and {{ macros }}",