Rename DatabaseConnection to Connection
This commit is contained in:
Родитель
46645898f2
Коммит
76af64332d
|
@ -28,8 +28,8 @@ WORKER_LOG_SERVER_PORT = 8793
|
|||
|
||||
[hooks]
|
||||
HIVE_HOME_PY: '/usr/lib/hive/lib/py'
|
||||
PRESTO_DEFAULT_DBID: presto_default
|
||||
HIVE_DEFAULT_DBID: hive_default
|
||||
PRESTO_DEFAULT_CONN_ID: presto_default
|
||||
HIVE_DEFAULT_CONN_ID: hive_default
|
||||
|
||||
[misc]
|
||||
RUN_AS_MASTER: True
|
||||
|
|
|
@ -250,28 +250,28 @@ def initdb(args):
|
|||
|
||||
# Creating the local_mysql DB connection
|
||||
session = settings.Session()
|
||||
session.query(models.DatabaseConnection).delete()
|
||||
session.query(models.Connection).delete()
|
||||
session.add(
|
||||
models.DatabaseConnection(
|
||||
db_id='local_mysql', db_type='mysql',
|
||||
models.Connection(
|
||||
conn_id='local_mysql', db_type='mysql',
|
||||
host='localhost', login='airflow', password='airflow',
|
||||
schema='airflow'))
|
||||
session.commit()
|
||||
session.add(
|
||||
models.DatabaseConnection(
|
||||
db_id='mysql_default', db_type='mysql',
|
||||
models.Connection(
|
||||
conn_id='mysql_default', db_type='mysql',
|
||||
host='localhost', login='airflow', password='airflow',
|
||||
schema='airflow'))
|
||||
session.commit()
|
||||
session.add(
|
||||
models.DatabaseConnection(
|
||||
db_id='presto_default', db_type='presto',
|
||||
models.Connection(
|
||||
conn_id='presto_default', db_type='presto',
|
||||
host='localhost',
|
||||
schema='hive', port=10001))
|
||||
session.commit()
|
||||
session.add(
|
||||
models.DatabaseConnection(
|
||||
db_id='hive_default', db_type='hive',
|
||||
models.Connection(
|
||||
conn_id='hive_default', db_type='hive',
|
||||
host='localhost',
|
||||
schema='default', port=10000))
|
||||
session.commit()
|
||||
|
|
|
@ -4,7 +4,7 @@ import subprocess
|
|||
import sys
|
||||
from tempfile import NamedTemporaryFile
|
||||
|
||||
from airflow.models import DatabaseConnection
|
||||
from airflow.models import Connection
|
||||
from airflow.configuration import conf
|
||||
from airflow import settings
|
||||
|
||||
|
@ -21,13 +21,13 @@ from airflow.hooks.base_hook import BaseHook
|
|||
|
||||
class HiveHook(BaseHook):
|
||||
def __init__(self,
|
||||
hive_dbid=conf.get('hooks', 'HIVE_DEFAULT_DBID')):
|
||||
hive_conn_id=conf.get('hooks', 'HIVE_DEFAULT_CONN_ID')):
|
||||
session = settings.Session()
|
||||
db = session.query(
|
||||
DatabaseConnection).filter(
|
||||
DatabaseConnection.db_id == hive_dbid)
|
||||
Connection).filter(
|
||||
Connection.conn_id == hive_conn_id)
|
||||
if db.count() == 0:
|
||||
raise Exception("The dbid you provided isn't defined")
|
||||
raise Exception("The conn_id you provided isn't defined")
|
||||
else:
|
||||
db = db.all()[0]
|
||||
self.host = db.host
|
||||
|
|
|
@ -1,13 +1,14 @@
|
|||
import MySQLdb
|
||||
from airflow import settings
|
||||
from airflow.models import DatabaseConnection
|
||||
from airflow.models import Connection
|
||||
|
||||
|
||||
class MySqlHook(object):
|
||||
|
||||
def __init__(
|
||||
self, host=None, login=None, psw=None, db=None, mysql_dbid=None):
|
||||
if not mysql_dbid:
|
||||
self, host=None, login=None,
|
||||
psw=None, db=None, mysql_conn_id=None):
|
||||
if not mysql_conn_id:
|
||||
self.host = host
|
||||
self.login = login
|
||||
self.psw = psw
|
||||
|
@ -15,8 +16,8 @@ class MySqlHook(object):
|
|||
else:
|
||||
session = settings.Session()
|
||||
db = session.query(
|
||||
DatabaseConnection).filter(
|
||||
DatabaseConnection.db_id == mysql_dbid)
|
||||
Connection).filter(
|
||||
Connection.conn_id == mysql_conn_id)
|
||||
if db.count() == 0:
|
||||
raise Exception("The mysql_dbid you provided isn't defined")
|
||||
else:
|
||||
|
|
|
@ -1,8 +1,6 @@
|
|||
import subprocess
|
||||
|
||||
from airflow import settings
|
||||
from airflow.configuration import conf
|
||||
from airflow.models import DatabaseConnection
|
||||
from airflow.models import Connection
|
||||
from airflow.hooks.base_hook import BaseHook
|
||||
from airflow.hooks.presto.presto_client import PrestoClient
|
||||
|
||||
|
@ -13,13 +11,14 @@ class PrestoHook(BaseHook):
|
|||
"""
|
||||
Interact with Presto!
|
||||
"""
|
||||
def __init__(self, presto_dbid=conf.get('hooks', 'PRESTO_DEFAULT_DBID')):
|
||||
def __init__(
|
||||
self, presto_conn_id=conf.get('hooks', 'PRESTO_DEFAULT_CONN_ID')):
|
||||
session = settings.Session()
|
||||
db = session.query(
|
||||
DatabaseConnection).filter(
|
||||
DatabaseConnection.db_id == presto_dbid)
|
||||
Connection).filter(
|
||||
Connection.conn_id == presto_conn_id)
|
||||
if db.count() == 0:
|
||||
raise Exception("The presto_dbid you provided isn't defined")
|
||||
raise Exception("The presto_conn_id you provided isn't defined")
|
||||
else:
|
||||
db = db.all()[0]
|
||||
self.host = db.host
|
||||
|
|
|
@ -2,9 +2,9 @@ from airflow.configuration import conf
|
|||
|
||||
def max_partition(
|
||||
table, schema="default",
|
||||
hive_dbid=conf.get('hooks', 'HIVE_DEFAULT_DBID')):
|
||||
hive_conn_id=conf.get('hooks', 'HIVE_DEFAULT_CONN_ID')):
|
||||
from airflow.hooks.hive_hook import HiveHook
|
||||
if '.' in table:
|
||||
schema, table = table.split('.')
|
||||
hh = HiveHook(hive_dbid=hive_dbid)
|
||||
hh = HiveHook(hive_conn_id=hive_conn_id)
|
||||
return hh.max_partition(schema=schema, table_name=table)
|
||||
|
|
|
@ -130,18 +130,18 @@ class User(Base):
|
|||
return False
|
||||
|
||||
|
||||
class DatabaseConnection(Base):
|
||||
class Connection(Base):
|
||||
"""
|
||||
Placeholder to store information about different database instances
|
||||
connection information. The idea here is that scripts use references to
|
||||
database instances (db_id) instead of hard coding hostname, logins and
|
||||
database instances (conn_id) instead of hard coding hostname, logins and
|
||||
passwords when using operators or hooks.
|
||||
"""
|
||||
__tablename__ = "db_connection"
|
||||
__tablename__ = "connection"
|
||||
|
||||
id = Column(Integer(), primary_key=True)
|
||||
db_id = Column(String(ID_LEN), unique=True)
|
||||
db_type = Column(String(500))
|
||||
conn_id = Column(String(ID_LEN), unique=True)
|
||||
conn_type = Column(String(500))
|
||||
host = Column(String(500))
|
||||
schema = Column(String(500))
|
||||
login = Column(String(500))
|
||||
|
@ -150,11 +150,11 @@ class DatabaseConnection(Base):
|
|||
extra = Column(String(5000))
|
||||
|
||||
def __init__(
|
||||
self, db_id=None, db_type=None,
|
||||
self, conn_id=None, conn_type=None,
|
||||
host=None, login=None, password=None,
|
||||
schema=None, port=None):
|
||||
self.db_id = db_id
|
||||
self.db_type = db_type
|
||||
self.conn_id = conn_id
|
||||
self.conn_type = conn_type
|
||||
self.host = host
|
||||
self.login = login
|
||||
self.password = password
|
||||
|
@ -163,15 +163,15 @@ class DatabaseConnection(Base):
|
|||
|
||||
def get_hook(self):
|
||||
from airflow import hooks
|
||||
if self.db_type == 'mysql':
|
||||
return hooks.MySqlHook(mysql_dbid=self.db_id)
|
||||
elif self.db_type == 'hive':
|
||||
return hooks.HiveHook(hive_dbid=self.db_id)
|
||||
elif self.db_type == 'presto':
|
||||
return hooks.PrestoHook(presto_dbid=self.db_id)
|
||||
if self.conn_type == 'mysql':
|
||||
return hooks.MySqlHook(mysql_conn_id=self.conn_id)
|
||||
elif self.conn_type == 'hive':
|
||||
return hooks.HiveHook(hive_conn_id=self.conn_id)
|
||||
elif self.conn_type == 'presto':
|
||||
return hooks.PrestoHook(presto_conn_id=self.conn_id)
|
||||
|
||||
def __repr__(self):
|
||||
return self.db_id
|
||||
return self.conn_id
|
||||
|
||||
|
||||
class DagPickle(Base):
|
||||
|
@ -1150,7 +1150,7 @@ class Chart(Base):
|
|||
|
||||
id = Column(Integer, primary_key=True)
|
||||
label = Column(String(200))
|
||||
db_id = Column(String(ID_LEN), ForeignKey('db_connection.db_id'))
|
||||
conn_id = Column(String(ID_LEN), ForeignKey('connection.conn_id'))
|
||||
user_id = Column(Integer(), ForeignKey('user.id'),)
|
||||
chart_type = Column(String(100), default="line")
|
||||
sql_layout = Column(String(50), default="series")
|
||||
|
@ -1162,5 +1162,5 @@ class Chart(Base):
|
|||
default_params = Column(String(5000), default="{}")
|
||||
owner = relationship("User", cascade=False, cascade_backrefs=False)
|
||||
x_is_date = Column(Boolean, default=True)
|
||||
db = relationship("DatabaseConnection")
|
||||
db = relationship("Connection")
|
||||
iteration_no = Column(Integer, default=0)
|
||||
|
|
|
@ -12,8 +12,8 @@ class HiveOperator(BaseOperator):
|
|||
|
||||
:param hql: the hql to be executed
|
||||
:type hql: string
|
||||
:param hive_dbid: reference to the Hive database
|
||||
:type hive_dbid: string
|
||||
:param hive_conn_id: reference to the Hive database
|
||||
:type hive_conn_id: string
|
||||
"""
|
||||
|
||||
__mapper_args__ = {
|
||||
|
@ -25,12 +25,12 @@ class HiveOperator(BaseOperator):
|
|||
@apply_defaults
|
||||
def __init__(
|
||||
self, hql,
|
||||
hive_dbid=conf.get('hooks', 'HIVE_DEFAULT_DBID'),
|
||||
hive_conn_id=conf.get('hooks', 'HIVE_DEFAULT_CONN_ID'),
|
||||
*args, **kwargs):
|
||||
super(HiveOperator, self).__init__(*args, **kwargs)
|
||||
|
||||
self.hive_dbid = hive_dbid
|
||||
self.hook = HiveHook(hive_dbid=hive_dbid)
|
||||
self.hive_conn_id = hive_conn_id
|
||||
self.hook = HiveHook(hive_conn_id=hive_conn_id)
|
||||
self.hql = hql
|
||||
|
||||
def execute(self, execution_date):
|
||||
|
|
|
@ -17,15 +17,15 @@ class MySqlOperator(BaseOperator):
|
|||
template_ext = ('.sql',)
|
||||
|
||||
@apply_defaults
|
||||
def __init__(self, sql, mysql_dbid, *args, **kwargs):
|
||||
def __init__(self, sql, mysql_conn_id, *args, **kwargs):
|
||||
"""
|
||||
Parameters:
|
||||
mysql_dbid: reference to a specific mysql database
|
||||
mysql_conn_id: reference to a specific mysql database
|
||||
sql: the sql code you to be executed
|
||||
"""
|
||||
super(MySqlOperator, self).__init__(*args, **kwargs)
|
||||
|
||||
self.hook = MySqlHook(mysql_dbid=mysql_dbid)
|
||||
self.hook = MySqlHook(mysql_conn_id=mysql_conn_id)
|
||||
self.sql = sql
|
||||
|
||||
def execute(self, execution_date):
|
||||
|
|
|
@ -6,7 +6,7 @@ from airflow import settings
|
|||
from airflow.configuration import conf
|
||||
from airflow.hooks import HiveHook
|
||||
from airflow.models import BaseOperator
|
||||
from airflow.models import DatabaseConnection as DB
|
||||
from airflow.models import Connection as DB
|
||||
from airflow.models import State
|
||||
from airflow.models import TaskInstance
|
||||
from airflow.utils import apply_defaults
|
||||
|
@ -56,17 +56,17 @@ class SqlSensor(BaseSensorOperator):
|
|||
}
|
||||
|
||||
@apply_defaults
|
||||
def __init__(self, db_id, sql, *args, **kwargs):
|
||||
def __init__(self, conn_id, sql, *args, **kwargs):
|
||||
|
||||
super(SqlSensor, self).__init__(*args, **kwargs)
|
||||
|
||||
self.sql = sql
|
||||
self.db_id = db_id
|
||||
self.conn_id = conn_id
|
||||
|
||||
session = settings.Session()
|
||||
db = session.query(DB).filter(DB.db_id==db_id).all()
|
||||
db = session.query(DB).filter(DB.conn_id==conn_id).all()
|
||||
if not db:
|
||||
raise Exception("db_id doesn't exist in the repository")
|
||||
raise Exception("conn_id doesn't exist in the repository")
|
||||
self.hook = db[0].get_hook()
|
||||
session.commit()
|
||||
session.close()
|
||||
|
@ -132,14 +132,14 @@ class HivePartitionSensor(BaseSensorOperator):
|
|||
def __init__(
|
||||
self,
|
||||
table, partition="ds='{{ ds }}'",
|
||||
hive_dbid=conf.get('hooks', 'HIVE_DEFAULT_DBID'),
|
||||
hive_conn_id=conf.get('hooks', 'HIVE_DEFAULT_CONN_ID'),
|
||||
schema='default',
|
||||
*args, **kwargs):
|
||||
super(HivePartitionSensor, self).__init__(*args, **kwargs)
|
||||
if '.' in table:
|
||||
schema, table = table.split('.')
|
||||
self.hive_dbid = hive_dbid
|
||||
self.hook = HiveHook(hive_dbid=hive_dbid)
|
||||
self.hive_conn_id = hive_conn_id
|
||||
self.hook = HiveHook(hive_conn_id=hive_conn_id)
|
||||
self.table = table
|
||||
self.partition = partition
|
||||
self.schema = schema
|
||||
|
|
|
@ -134,24 +134,24 @@ class Airflow(BaseView):
|
|||
@wwwutils.gzipped
|
||||
def query(self):
|
||||
session = settings.Session()
|
||||
dbs = session.query(models.DatabaseConnection).order_by(
|
||||
models.DatabaseConnection.db_id)
|
||||
db_choices = [(db.db_id, db.db_id) for db in dbs]
|
||||
db_id_str = request.args.get('db_id')
|
||||
dbs = session.query(models.Connection).order_by(
|
||||
models.Connection.conn_id)
|
||||
db_choices = [(db.conn_id, db.conn_id) for db in dbs]
|
||||
conn_id_str = request.args.get('conn_id')
|
||||
sql = request.args.get('sql')
|
||||
|
||||
class QueryForm(Form):
|
||||
db_id = SelectField("Layout", choices=db_choices)
|
||||
conn_id = SelectField("Layout", choices=db_choices)
|
||||
sql = TextAreaField("SQL", widget=wwwutils.AceEditorWidget())
|
||||
data = {
|
||||
'db_id': db_id_str,
|
||||
'conn_id': conn_id_str,
|
||||
'sql': sql,
|
||||
}
|
||||
results = None
|
||||
has_data = False
|
||||
error = False
|
||||
if db_id_str:
|
||||
db = [db for db in dbs if db.db_id == db_id_str][0]
|
||||
if conn_id_str:
|
||||
db = [db for db in dbs if db.conn_id == conn_id_str][0]
|
||||
hook = db.get_hook()
|
||||
try:
|
||||
# df = hook.get_pandas_df(wwwutils.limit_sql(sql, QUERY_LIMIT))
|
||||
|
@ -193,7 +193,7 @@ class Airflow(BaseView):
|
|||
chart_id = request.args.get('chart_id')
|
||||
chart = session.query(models.Chart).filter_by(id=chart_id).all()[0]
|
||||
db = session.query(
|
||||
models.DatabaseConnection).filter_by(db_id=chart.db_id).all()[0]
|
||||
models.Connection).filter_by(conn_id=chart.conn_id).all()[0]
|
||||
session.expunge_all()
|
||||
|
||||
payload = {}
|
||||
|
@ -1022,8 +1022,8 @@ mv = UserModelView(models.User, Session, name="Users", category="Admin")
|
|||
admin.add_view(mv)
|
||||
|
||||
|
||||
class DatabaseConnectionModelView(LoginMixin, ModelView):
|
||||
column_list = ('db_id', 'db_type', 'host', 'port')
|
||||
class ConnectionModelView(LoginMixin, ModelView):
|
||||
column_list = ('conn_id', 'db_type', 'host', 'port')
|
||||
form_choices = {
|
||||
'db_type': [
|
||||
('hive', 'Hive',),
|
||||
|
@ -1034,9 +1034,9 @@ class DatabaseConnectionModelView(LoginMixin, ModelView):
|
|||
('ftp', 'FTP',),
|
||||
]
|
||||
}
|
||||
mv = DatabaseConnectionModelView(
|
||||
models.DatabaseConnection, Session,
|
||||
name="Database Connections", category="Admin")
|
||||
mv = ConnectionModelView(
|
||||
models.Connection, Session,
|
||||
name="Connections", category="Admin")
|
||||
admin.add_view(mv)
|
||||
|
||||
|
||||
|
@ -1085,12 +1085,12 @@ class ChartModelView(LoginMixin, ModelView):
|
|||
'sql',
|
||||
'default_params',)
|
||||
column_list = (
|
||||
'label', 'db_id', 'chart_type', 'owner',
|
||||
'label', 'conn_id', 'chart_type', 'owner',
|
||||
'show_datatable', 'show_sql',)
|
||||
column_formatters = dict(label=label_link)
|
||||
create_template = 'airflow/chart/create.html'
|
||||
edit_template = 'airflow/chart/edit.html'
|
||||
column_filters = ('owner.username', 'db_id',)
|
||||
column_filters = ('owner.username', 'conn_id',)
|
||||
column_searchable_list = ('owner.username', 'label', 'sql')
|
||||
column_descriptions = {
|
||||
'label': "Can include {{ templated_fields }} and {{ macros }}",
|
||||
|
|
Загрузка…
Ссылка в новой задаче