[AIRFLOW-1275] Put 'airflow pool' into API

Closes #2346 from skudriashev/airflow-1275
This commit is contained in:
Stanislav Kudriashev 2017-06-21 16:36:45 +02:00 коммит произвёл Bolke de Bruin
Родитель a45e2d1888
Коммит 9958aa9d53
14 изменённых файлов: 673 добавлений и 98 удалений

Просмотреть файл

@ -14,17 +14,47 @@
# #
class Client: class Client(object):
"""Base API client for all API clients."""
def __init__(self, api_base_url, auth): def __init__(self, api_base_url, auth):
self._api_base_url = api_base_url self._api_base_url = api_base_url
self._auth = auth self._auth = auth
def trigger_dag(self, dag_id, run_id=None, conf=None, execution_date=None): def trigger_dag(self, dag_id, run_id=None, conf=None, execution_date=None):
""" """Create a dag run for the specified dag.
Creates a dag run for the specified dag
:param dag_id: :param dag_id:
:param run_id: :param run_id:
:param conf: :param conf:
:param execution_date:
:return: :return:
""" """
raise NotImplementedError() raise NotImplementedError()
def get_pool(self, name):
"""Get pool.
:param name: pool name
"""
raise NotImplementedError()
def get_pools(self):
"""Get all pools."""
raise NotImplementedError()
def create_pool(self, name, slots, description):
"""Create a pool.
:param name: pool name
:param slots: pool slots amount
:param description: pool description
"""
raise NotImplementedError()
def delete_pool(self, name):
"""Delete pool.
:param name: pool name
"""
raise NotImplementedError()

Просмотреть файл

@ -11,30 +11,70 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
#
from future.moves.urllib.parse import urljoin from future.moves.urllib.parse import urljoin
import requests
from airflow.api.client import api_client from airflow.api.client import api_client
import requests
class Client(api_client.Client): class Client(api_client.Client):
"""Json API client implementation."""
def _request(self, url, method='GET', json=None):
params = {
'url': url,
'auth': self._auth,
}
if json is not None:
params['json'] = json
resp = getattr(requests, method.lower())(**params)
if not resp.ok:
try:
data = resp.json()
except Exception:
data = {}
raise IOError(data.get('error', 'Server error'))
return resp.json()
def trigger_dag(self, dag_id, run_id=None, conf=None, execution_date=None): def trigger_dag(self, dag_id, run_id=None, conf=None, execution_date=None):
endpoint = '/api/experimental/dags/{}/dag_runs'.format(dag_id) endpoint = '/api/experimental/dags/{}/dag_runs'.format(dag_id)
url = urljoin(self._api_base_url, endpoint) url = urljoin(self._api_base_url, endpoint)
data = self._request(url, method='POST',
resp = requests.post(url,
auth=self._auth,
json={ json={
"run_id": run_id, "run_id": run_id,
"conf": conf, "conf": conf,
"execution_date": execution_date, "execution_date": execution_date,
}) })
if not resp.ok:
raise IOError()
data = resp.json()
return data['message'] return data['message']
def get_pool(self, name):
endpoint = '/api/experimental/pools/{}'.format(name)
url = urljoin(self._api_base_url, endpoint)
pool = self._request(url)
return pool['pool'], pool['slots'], pool['description']
def get_pools(self):
endpoint = '/api/experimental/pools'
url = urljoin(self._api_base_url, endpoint)
pools = self._request(url)
return [(p['pool'], p['slots'], p['description']) for p in pools]
def create_pool(self, name, slots, description):
endpoint = '/api/experimental/pools'
url = urljoin(self._api_base_url, endpoint)
pool = self._request(url, method='POST',
json={
'name': name,
'slots': slots,
'description': description,
})
return pool['pool'], pool['slots'], pool['description']
def delete_pool(self, name):
endpoint = '/api/experimental/pools/{}'.format(name)
url = urljoin(self._api_base_url, endpoint)
pool = self._request(url, method='DELETE')
return pool['pool'], pool['slots'], pool['description']

Просмотреть файл

@ -11,15 +11,33 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
#
from airflow.api.client import api_client from airflow.api.client import api_client
from airflow.api.common.experimental import pool
from airflow.api.common.experimental import trigger_dag from airflow.api.common.experimental import trigger_dag
class Client(api_client.Client): class Client(api_client.Client):
"""Local API client implementation."""
def trigger_dag(self, dag_id, run_id=None, conf=None, execution_date=None): def trigger_dag(self, dag_id, run_id=None, conf=None, execution_date=None):
dr = trigger_dag.trigger_dag(dag_id=dag_id, dr = trigger_dag.trigger_dag(dag_id=dag_id,
run_id=run_id, run_id=run_id,
conf=conf, conf=conf,
execution_date=execution_date) execution_date=execution_date)
return "Created {}".format(dr) return "Created {}".format(dr)
def get_pool(self, name):
p = pool.get_pool(name=name)
return p.pool, p.slots, p.description
def get_pools(self):
return [(p.pool, p.slots, p.description) for p in pool.get_pools()]
def create_pool(self, name, slots, description):
p = pool.create_pool(name=name, slots=slots, description=description)
return p.pool, p.slots, p.description
def delete_pool(self, name):
p = pool.delete_pool(name=name)
return p.pool, p.slots, p.description

Просмотреть файл

@ -0,0 +1,85 @@
# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from airflow.exceptions import AirflowException
from airflow.models import Pool
from airflow.utils.db import provide_session
class PoolBadRequest(AirflowException):
status = 400
class PoolNotFound(AirflowException):
status = 404
@provide_session
def get_pool(name, session=None):
"""Get pool by a given name."""
if not (name and name.strip()):
raise PoolBadRequest("Pool name shouldn't be empty")
pool = session.query(Pool).filter_by(pool=name).first()
if pool is None:
raise PoolNotFound("Pool '%s' doesn't exist" % name)
return pool
@provide_session
def get_pools(session=None):
"""Get all pools."""
return session.query(Pool).all()
@provide_session
def create_pool(name, slots, description, session=None):
"""Create a pool with a given parameters."""
if not (name and name.strip()):
raise PoolBadRequest("Pool name shouldn't be empty")
try:
slots = int(slots)
except ValueError:
raise PoolBadRequest("Bad value for `slots`: %s" % slots)
session.expire_on_commit = False
pool = session.query(Pool).filter_by(pool=name).first()
if pool is None:
pool = Pool(pool=name, slots=slots, description=description)
session.add(pool)
else:
pool.slots = slots
pool.description = description
session.commit()
return pool
@provide_session
def delete_pool(name, session=None):
"""Delete pool by a given name."""
if not (name and name.strip()):
raise PoolBadRequest("Pool name shouldn't be empty")
pool = session.query(Pool).filter_by(pool=name).first()
if pool is None:
raise PoolNotFound("Pool '%s' doesn't exist" % name)
session.delete(pool)
session.commit()
return pool

Просмотреть файл

@ -49,7 +49,7 @@ from airflow.exceptions import AirflowException
from airflow.executors import GetDefaultExecutor from airflow.executors import GetDefaultExecutor
from airflow.models import (DagModel, DagBag, TaskInstance, from airflow.models import (DagModel, DagBag, TaskInstance,
DagPickle, DagRun, Variable, DagStat, DagPickle, DagRun, Variable, DagStat,
Pool, Connection) Connection)
from airflow.ti_deps.dep_context import (DepContext, SCHEDULER_DEPS) from airflow.ti_deps.dep_context import (DepContext, SCHEDULER_DEPS)
from airflow.utils import db as db_utils from airflow.utils import db as db_utils
from airflow.utils import logging as logging_utils from airflow.utils import logging as logging_utils
@ -187,40 +187,28 @@ def trigger_dag(args):
def pool(args): def pool(args):
session = settings.Session() def _tabulate(pools):
if args.get or (args.set and args.set[0]) or args.delete: return "\n%s" % tabulate(pools, ['Pool', 'Slots', 'Description'],
name = args.get or args.delete or args.set[0] tablefmt="fancy_grid")
pool = (
session.query(Pool) try:
.filter(Pool.pool == name) if args.get is not None:
.first()) pools = [api_client.get_pool(name=args.get)]
if pool and args.get: elif args.set:
print("{} ".format(pool)) pools = [api_client.create_pool(name=args.set[0],
return slots=args.set[1],
elif not pool and (args.get or args.delete): description=args.set[2])]
print("No pool named {} found".format(name)) elif args.delete:
elif not pool and args.set: pools = [api_client.delete_pool(name=args.delete)]
pool = Pool( else:
pool=name, pools = api_client.get_pools()
slots=args.set[1], except (AirflowException, IOError) as err:
description=args.set[2]) logging.error(err)
session.add(pool) else:
session.commit() logging.info(_tabulate(pools=pools))
print("{} ".format(pool))
elif pool and args.set:
pool.slots = args.set[1]
pool.description = args.set[2]
session.commit()
print("{} ".format(pool))
return
elif pool and args.delete:
session.query(Pool).filter_by(pool=args.delete).delete()
session.commit()
print("Pool {} deleted".format(name))
def variables(args): def variables(args):
if args.get: if args.get:
try: try:
var = Variable.get(args.get, var = Variable.get(args.get,

Просмотреть файл

@ -4395,6 +4395,14 @@ class Pool(Base):
def __repr__(self): def __repr__(self):
return self.pool return self.pool
def to_json(self):
return {
'id': self.id,
'pool': self.pool,
'slots': self.slots,
'description': self.description,
}
@provide_session @provide_session
def used_slots(self, session): def used_slots(self, session):
""" """

Просмотреть файл

@ -11,10 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
import airflow.api import airflow.api
from airflow.api.common.experimental import pool as pool_api
from airflow.api.common.experimental import trigger_dag as trigger from airflow.api.common.experimental import trigger_dag as trigger
from airflow.api.common.experimental.get_task import get_task from airflow.api.common.experimental.get_task import get_task
from airflow.api.common.experimental.get_task_instance import get_task_instance from airflow.api.common.experimental.get_task_instance import get_task_instance
@ -96,7 +98,6 @@ def test():
@requires_authentication @requires_authentication
def task_info(dag_id, task_id): def task_info(dag_id, task_id):
"""Returns a JSON with a task's public instance variables. """ """Returns a JSON with a task's public instance variables. """
try: try:
info = get_task(dag_id, task_id) info = get_task(dag_id, task_id)
except AirflowException as err: except AirflowException as err:
@ -169,4 +170,67 @@ def latest_dag_runs():
'dag_run_url': url_for('airflow.graph', dag_id=dagrun.dag_id, 'dag_run_url': url_for('airflow.graph', dag_id=dagrun.dag_id,
execution_date=dagrun.execution_date) execution_date=dagrun.execution_date)
}) })
return jsonify(items=payload) # old flask versions dont support jsonifying arrays return jsonify(items=payload) # old flask versions dont support jsonifying arrays
@api_experimental.route('/pools/<string:name>', methods=['GET'])
@requires_authentication
def get_pool(name):
"""Get pool by a given name."""
try:
pool = pool_api.get_pool(name=name)
except AirflowException as e:
_log.error(e)
response = jsonify(error="{}".format(e))
response.status_code = getattr(e, 'status', 500)
return response
else:
return jsonify(pool.to_json())
@api_experimental.route('/pools', methods=['GET'])
@requires_authentication
def get_pools():
"""Get all pools."""
try:
pools = pool_api.get_pools()
except AirflowException as e:
_log.error(e)
response = jsonify(error="{}".format(e))
response.status_code = getattr(e, 'status', 500)
return response
else:
return jsonify([p.to_json() for p in pools])
@csrf.exempt
@api_experimental.route('/pools', methods=['POST'])
@requires_authentication
def create_pool():
"""Create a pool."""
params = request.get_json(force=True)
try:
pool = pool_api.create_pool(**params)
except AirflowException as e:
_log.error(e)
response = jsonify(error="{}".format(e))
response.status_code = getattr(e, 'status', 500)
return response
else:
return jsonify(pool.to_json())
@csrf.exempt
@api_experimental.route('/pools/<string:name>', methods=['DELETE'])
@requires_authentication
def delete_pool(name):
"""Delete pool."""
try:
pool = pool_api.delete_pool(name=name)
except AirflowException as e:
_log.error(e)
response = jsonify(error="{}".format(e))
response.status_code = getattr(e, 'status', 500)
return response
else:
return jsonify(pool.to_json())

Просмотреть файл

@ -11,9 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import absolute_import
from .client import *
from .common import *

Просмотреть файл

@ -12,16 +12,16 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import datetime
import json import json
import unittest import unittest
import datetime
from mock import patch from mock import patch
from airflow import AirflowException from airflow import AirflowException
from airflow import models
from airflow.api.client.local_client import Client from airflow.api.client.local_client import Client
from airflow import models
from airflow import settings
from airflow.utils.state import State from airflow.utils.state import State
EXECDATE = datetime.datetime.now() EXECDATE = datetime.datetime.now()
@ -53,8 +53,25 @@ def mock_datetime_now(target, dt):
class TestLocalClient(unittest.TestCase): class TestLocalClient(unittest.TestCase):
@classmethod
def setUpClass(cls):
super(TestLocalClient, cls).setUpClass()
session = settings.Session()
session.query(models.Pool).delete()
session.commit()
session.close()
def setUp(self): def setUp(self):
super(TestLocalClient, self).setUp()
self.client = Client(api_base_url=None, auth=None) self.client = Client(api_base_url=None, auth=None)
self.session = settings.Session()
def tearDown(self):
self.session.query(models.Pool).delete()
self.session.commit()
self.session.close()
super(TestLocalClient, self).tearDown()
@patch.object(models.DAG, 'create_dagrun') @patch.object(models.DAG, 'create_dagrun')
def test_trigger_dag(self, mock): def test_trigger_dag(self, mock):
@ -104,4 +121,24 @@ class TestLocalClient(unittest.TestCase):
external_trigger=True) external_trigger=True)
mock.reset_mock() mock.reset_mock()
# this is a unit test only, cannot verify existing dag run def test_get_pool(self):
self.client.create_pool(name='foo', slots=1, description='')
pool = self.client.get_pool(name='foo')
self.assertEqual(pool, ('foo', 1, ''))
def test_get_pools(self):
self.client.create_pool(name='foo1', slots=1, description='')
self.client.create_pool(name='foo2', slots=2, description='')
pools = sorted(self.client.get_pools(), key=lambda p: p[0])
self.assertEqual(pools, [('foo1', 1, ''), ('foo2', 2, '')])
def test_create_pool(self):
pool = self.client.create_pool(name='foo', slots=1, description='')
self.assertEqual(pool, ('foo', 1, ''))
self.assertEqual(self.session.query(models.Pool).count(), 1)
def test_delete_pool(self):
self.client.create_pool(name='foo', slots=1, description='')
self.assertEqual(self.session.query(models.Pool).count(), 1)
self.client.delete_pool(name='foo')
self.assertEqual(self.session.query(models.Pool).count(), 0)

Просмотреть файл

@ -0,0 +1,13 @@
# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Просмотреть файл

@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
#
import unittest import unittest
@ -27,6 +26,7 @@ DEV_NULL = "/dev/null"
class TestMarkTasks(unittest.TestCase): class TestMarkTasks(unittest.TestCase):
def setUp(self): def setUp(self):
self.dagbag = models.DagBag(include_examples=True) self.dagbag = models.DagBag(include_examples=True)
self.dag1 = self.dagbag.dags['test_example_bash_operator'] self.dag1 = self.dagbag.dags['test_example_bash_operator']
@ -52,6 +52,16 @@ class TestMarkTasks(unittest.TestCase):
self.session = Session() self.session = Session()
def tearDown(self):
self.dag1.clear()
self.dag2.clear()
# just to make sure we are fully cleaned up
self.session.query(models.DagRun).delete()
self.session.query(models.TaskInstance).delete()
self.session.commit()
self.session.close()
def snapshot_state(self, dag, execution_dates): def snapshot_state(self, dag, execution_dates):
TI = models.TaskInstance TI = models.TaskInstance
tis = self.session.query(TI).filter( tis = self.session.query(TI).filter(
@ -197,16 +207,6 @@ class TestMarkTasks(unittest.TestCase):
self.verify_state(self.dag2, task_ids, [self.execution_dates[0]], self.verify_state(self.dag2, task_ids, [self.execution_dates[0]],
State.SUCCESS, []) State.SUCCESS, [])
def tearDown(self):
self.dag1.clear()
self.dag2.clear()
# just to make sure we are fully cleaned up
self.session.query(models.DagRun).delete()
self.session.query(models.TaskInstance).delete()
self.session.commit()
self.session.close()
class TestMarkDAGRun(unittest.TestCase): class TestMarkDAGRun(unittest.TestCase):
def setUp(self): def setUp(self):

Просмотреть файл

@ -0,0 +1,132 @@
# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from airflow.api.common.experimental import pool as pool_api
from airflow import models
from airflow import settings
class TestPool(unittest.TestCase):
def setUp(self):
super(TestPool, self).setUp()
self.session = settings.Session()
self.pools = []
for i in range(2):
name = 'experimental_%s' % (i + 1)
pool = models.Pool(
pool=name,
slots=i,
description=name,
)
self.session.add(pool)
self.pools.append(pool)
self.session.commit()
def tearDown(self):
self.session.query(models.Pool).delete()
self.session.commit()
self.session.close()
super(TestPool, self).tearDown()
def test_get_pool(self):
pool = pool_api.get_pool(name=self.pools[0].pool, session=self.session)
self.assertEqual(pool.pool, self.pools[0].pool)
def test_get_pool_non_existing(self):
self.assertRaisesRegexp(pool_api.PoolNotFound,
"^Pool 'test' doesn't exist$",
pool_api.get_pool,
name='test',
session=self.session)
def test_get_pool_bad_name(self):
for name in ('', ' '):
self.assertRaisesRegexp(pool_api.PoolBadRequest,
"^Pool name shouldn't be empty$",
pool_api.get_pool,
name=name,
session=self.session)
def test_get_pools(self):
pools = sorted(pool_api.get_pools(session=self.session),
key=lambda p: p.pool)
self.assertEqual(pools[0].pool, self.pools[0].pool)
self.assertEqual(pools[1].pool, self.pools[1].pool)
def test_create_pool(self):
pool = pool_api.create_pool(name='foo',
slots=5,
description='',
session=self.session)
self.assertEqual(pool.pool, 'foo')
self.assertEqual(pool.slots, 5)
self.assertEqual(pool.description, '')
self.assertEqual(self.session.query(models.Pool).count(), 3)
def test_create_pool_existing(self):
pool = pool_api.create_pool(name=self.pools[0].pool,
slots=5,
description='',
session=self.session)
self.assertEqual(pool.pool, self.pools[0].pool)
self.assertEqual(pool.slots, 5)
self.assertEqual(pool.description, '')
self.assertEqual(self.session.query(models.Pool).count(), 2)
def test_create_pool_bad_name(self):
for name in ('', ' '):
self.assertRaisesRegexp(pool_api.PoolBadRequest,
"^Pool name shouldn't be empty$",
pool_api.create_pool,
name=name,
slots=5,
description='',
session=self.session)
def test_create_pool_bad_slots(self):
self.assertRaisesRegexp(pool_api.PoolBadRequest,
"^Bad value for `slots`: foo$",
pool_api.create_pool,
name='foo',
slots='foo',
description='',
session=self.session)
def test_delete_pool(self):
pool = pool_api.delete_pool(name=self.pools[0].pool,
session=self.session)
self.assertEqual(pool.pool, self.pools[0].pool)
self.assertEqual(self.session.query(models.Pool).count(), 1)
def test_delete_pool_non_existing(self):
self.assertRaisesRegexp(pool_api.PoolNotFound,
"^Pool 'test' doesn't exist$",
pool_api.delete_pool,
name='test',
session=self.session)
def test_delete_pool_bad_name(self):
for name in ('', ' '):
self.assertRaisesRegexp(pool_api.PoolBadRequest,
"^Pool name shouldn't be empty$",
pool_api.delete_pool,
name=name,
session=self.session)
if __name__ == '__main__':
unittest.main()

Просмотреть файл

@ -1062,12 +1062,34 @@ class CoreTest(unittest.TestCase):
class CliTests(unittest.TestCase): class CliTests(unittest.TestCase):
@classmethod
def setUpClass(cls):
super(CliTests, cls).setUpClass()
cls._cleanup()
def setUp(self): def setUp(self):
super(CliTests, self).setUp()
configuration.load_test_config() configuration.load_test_config()
app = application.create_app() app = application.create_app()
app.config['TESTING'] = True app.config['TESTING'] = True
self.parser = cli.CLIFactory.get_parser() self.parser = cli.CLIFactory.get_parser()
self.dagbag = models.DagBag(dag_folder=DEV_NULL, include_examples=True) self.dagbag = models.DagBag(dag_folder=DEV_NULL, include_examples=True)
self.session = Session()
def tearDown(self):
self._cleanup(session=self.session)
super(CliTests, self).tearDown()
@staticmethod
def _cleanup(session=None):
if session is None:
session = Session()
session.query(models.Pool).delete()
session.query(models.Variable).delete()
session.commit()
session.close()
def test_cli_list_dags(self): def test_cli_list_dags(self):
args = self.parser.parse_args(['list_dags', '--report']) args = self.parser.parse_args(['list_dags', '--report'])
@ -1100,8 +1122,8 @@ class CliTests(unittest.TestCase):
cli.connections(self.parser.parse_args(['connections', '--list'])) cli.connections(self.parser.parse_args(['connections', '--list']))
stdout = mock_stdout.getvalue() stdout = mock_stdout.getvalue()
conns = [[x.strip("'") for x in re.findall("'\w+'", line)[:2]] conns = [[x.strip("'") for x in re.findall("'\w+'", line)[:2]]
for ii, line in enumerate(stdout.split('\n')) for ii, line in enumerate(stdout.split('\n'))
if ii % 2 == 1] if ii % 2 == 1]
conns = [conn for conn in conns if len(conn) > 0] conns = [conn for conn in conns if len(conn) > 0]
# Assert that some of the connections are present in the output as # Assert that some of the connections are present in the output as
@ -1365,14 +1387,27 @@ class CliTests(unittest.TestCase):
'-c', 'NOT JSON']) '-c', 'NOT JSON'])
) )
def test_pool(self): def test_pool_create(self):
# Checks if all subcommands are properly received cli.pool(self.parser.parse_args(['pool', '-s', 'foo', '1', 'test']))
cli.pool(self.parser.parse_args([ self.assertEqual(self.session.query(models.Pool).count(), 1)
'pool', '-s', 'foo', '1', '"my foo pool"']))
cli.pool(self.parser.parse_args([ def test_pool_get(self):
'pool', '-g', 'foo'])) cli.pool(self.parser.parse_args(['pool', '-s', 'foo', '1', 'test']))
cli.pool(self.parser.parse_args([ try:
'pool', '-x', 'foo'])) cli.pool(self.parser.parse_args(['pool', '-g', 'foo']))
except Exception as e:
self.fail("The 'pool -g foo' command raised unexpectedly: %s" % e)
def test_pool_delete(self):
cli.pool(self.parser.parse_args(['pool', '-s', 'foo', '1', 'test']))
cli.pool(self.parser.parse_args(['pool', '-x', 'foo']))
self.assertEqual(self.session.query(models.Pool).count(), 0)
def test_pool_no_args(self):
try:
cli.pool(self.parser.parse_args(['pool']))
except Exception as e:
self.fail("The 'pool' command raised unexpectedly: %s" % e)
def test_variables(self): def test_variables(self):
# Checks if all subcommands are properly received # Checks if all subcommands are properly received
@ -1426,10 +1461,6 @@ class CliTests(unittest.TestCase):
self.assertEqual('original', models.Variable.get('bar')) self.assertEqual('original', models.Variable.get('bar'))
self.assertEqual('{"foo": "bar"}', models.Variable.get('foo')) self.assertEqual('{"foo": "bar"}', models.Variable.get('foo'))
session = settings.Session()
session.query(Variable).delete()
session.commit()
session.close()
os.remove('variables1.json') os.remove('variables1.json')
os.remove('variables2.json') os.remove('variables2.json')

Просмотреть файл

@ -19,23 +19,36 @@ from urllib.parse import quote_plus
from airflow import configuration from airflow import configuration
from airflow.api.common.experimental.trigger_dag import trigger_dag from airflow.api.common.experimental.trigger_dag import trigger_dag
from airflow.models import DagBag, DagRun, TaskInstance from airflow.models import DagBag, DagRun, Pool, TaskInstance
from airflow.settings import Session from airflow.settings import Session
from airflow.www import app as application from airflow.www import app as application
class ApiExperimentalTests(unittest.TestCase): class TestApiExperimental(unittest.TestCase):
def setUp(self): @classmethod
configuration.load_test_config() def setUpClass(cls):
app = application.create_app(testing=True) super(TestApiExperimental, cls).setUpClass()
self.app = app.test_client()
session = Session() session = Session()
session.query(DagRun).delete() session.query(DagRun).delete()
session.query(TaskInstance).delete() session.query(TaskInstance).delete()
session.commit() session.commit()
session.close() session.close()
def setUp(self):
super(TestApiExperimental, self).setUp()
configuration.load_test_config()
app = application.create_app(testing=True)
self.app = app.test_client()
def tearDown(self):
session = Session()
session.query(DagRun).delete()
session.query(TaskInstance).delete()
session.commit()
session.close()
super(TestApiExperimental, self).tearDown()
def test_task_info(self): def test_task_info(self):
url_template = '/api/experimental/dags/{}/tasks/{}' url_template = '/api/experimental/dags/{}/tasks/{}'
@ -62,7 +75,7 @@ class ApiExperimentalTests(unittest.TestCase):
url_template = '/api/experimental/dags/{}/dag_runs' url_template = '/api/experimental/dags/{}/dag_runs'
response = self.app.post( response = self.app.post(
url_template.format('example_bash_operator'), url_template.format('example_bash_operator'),
data=json.dumps(dict(run_id='my_run' + datetime.now().isoformat())), data=json.dumps({'run_id': 'my_run' + datetime.now().isoformat()}),
content_type="application/json" content_type="application/json"
) )
@ -70,7 +83,7 @@ class ApiExperimentalTests(unittest.TestCase):
response = self.app.post( response = self.app.post(
url_template.format('does_not_exist_dag'), url_template.format('does_not_exist_dag'),
data=json.dumps(dict()), data=json.dumps({}),
content_type="application/json" content_type="application/json"
) )
self.assertEqual(404, response.status_code) self.assertEqual(404, response.status_code)
@ -88,7 +101,7 @@ class ApiExperimentalTests(unittest.TestCase):
# Test Correct execution # Test Correct execution
response = self.app.post( response = self.app.post(
url_template.format(dag_id), url_template.format(dag_id),
data=json.dumps(dict(execution_date=datetime_string)), data=json.dumps({'execution_date': datetime_string}),
content_type="application/json" content_type="application/json"
) )
self.assertEqual(200, response.status_code) self.assertEqual(200, response.status_code)
@ -103,7 +116,7 @@ class ApiExperimentalTests(unittest.TestCase):
# Test error for nonexistent dag # Test error for nonexistent dag
response = self.app.post( response = self.app.post(
url_template.format('does_not_exist_dag'), url_template.format('does_not_exist_dag'),
data=json.dumps(dict(execution_date=execution_date.isoformat())), data=json.dumps({'execution_date': execution_date.isoformat()}),
content_type="application/json" content_type="application/json"
) )
self.assertEqual(404, response.status_code) self.assertEqual(404, response.status_code)
@ -111,7 +124,7 @@ class ApiExperimentalTests(unittest.TestCase):
# Test error for bad datetime format # Test error for bad datetime format
response = self.app.post( response = self.app.post(
url_template.format(dag_id), url_template.format(dag_id),
data=json.dumps(dict(execution_date='not_a_datetime')), data=json.dumps({'execution_date': 'not_a_datetime'}),
content_type="application/json" content_type="application/json"
) )
self.assertEqual(400, response.status_code) self.assertEqual(400, response.status_code)
@ -122,7 +135,9 @@ class ApiExperimentalTests(unittest.TestCase):
task_id = 'also_run_this' task_id = 'also_run_this'
execution_date = datetime.now().replace(microsecond=0) execution_date = datetime.now().replace(microsecond=0)
datetime_string = quote_plus(execution_date.isoformat()) datetime_string = quote_plus(execution_date.isoformat())
wrong_datetime_string = quote_plus(datetime(1990, 1, 1, 1, 1, 1).isoformat()) wrong_datetime_string = quote_plus(
datetime(1990, 1, 1, 1, 1, 1).isoformat()
)
# Create DagRun # Create DagRun
trigger_dag(dag_id=dag_id, trigger_dag(dag_id=dag_id,
@ -139,7 +154,8 @@ class ApiExperimentalTests(unittest.TestCase):
# Test error for nonexistent dag # Test error for nonexistent dag
response = self.app.get( response = self.app.get(
url_template.format('does_not_exist_dag', datetime_string, task_id), url_template.format('does_not_exist_dag', datetime_string,
task_id),
) )
self.assertEqual(404, response.status_code) self.assertEqual(404, response.status_code)
self.assertIn('error', response.data.decode('utf-8')) self.assertIn('error', response.data.decode('utf-8'))
@ -164,3 +180,122 @@ class ApiExperimentalTests(unittest.TestCase):
) )
self.assertEqual(400, response.status_code) self.assertEqual(400, response.status_code)
self.assertIn('error', response.data.decode('utf-8')) self.assertIn('error', response.data.decode('utf-8'))
class TestPoolApiExperimental(unittest.TestCase):
@classmethod
def setUpClass(cls):
super(TestPoolApiExperimental, cls).setUpClass()
session = Session()
session.query(Pool).delete()
session.commit()
session.close()
def setUp(self):
super(TestPoolApiExperimental, self).setUp()
configuration.load_test_config()
app = application.create_app(testing=True)
self.app = app.test_client()
self.session = Session()
self.pools = []
for i in range(2):
name = 'experimental_%s' % (i + 1)
pool = Pool(
pool=name,
slots=i,
description=name,
)
self.session.add(pool)
self.pools.append(pool)
self.session.commit()
self.pool = self.pools[0]
def tearDown(self):
self.session.query(Pool).delete()
self.session.commit()
self.session.close()
super(TestPoolApiExperimental, self).tearDown()
def _get_pool_count(self):
response = self.app.get('/api/experimental/pools')
self.assertEqual(response.status_code, 200)
return len(json.loads(response.data.decode('utf-8')))
def test_get_pool(self):
response = self.app.get(
'/api/experimental/pools/{}'.format(self.pool.pool),
)
self.assertEqual(response.status_code, 200)
self.assertEqual(json.loads(response.data.decode('utf-8')),
self.pool.to_json())
def test_get_pool_non_existing(self):
response = self.app.get('/api/experimental/pools/foo')
self.assertEqual(response.status_code, 404)
self.assertEqual(json.loads(response.data.decode('utf-8'))['error'],
"Pool 'foo' doesn't exist")
def test_get_pools(self):
response = self.app.get('/api/experimental/pools')
self.assertEqual(response.status_code, 200)
pools = json.loads(response.data.decode('utf-8'))
self.assertEqual(len(pools), 2)
for i, pool in enumerate(sorted(pools, key=lambda p: p['pool'])):
self.assertDictEqual(pool, self.pools[i].to_json())
def test_create_pool(self):
response = self.app.post(
'/api/experimental/pools',
data=json.dumps({
'name': 'foo',
'slots': 1,
'description': '',
}),
content_type='application/json',
)
self.assertEqual(response.status_code, 200)
pool = json.loads(response.data.decode('utf-8'))
self.assertEqual(pool['pool'], 'foo')
self.assertEqual(pool['slots'], 1)
self.assertEqual(pool['description'], '')
self.assertEqual(self._get_pool_count(), 3)
def test_create_pool_with_bad_name(self):
for name in ('', ' '):
response = self.app.post(
'/api/experimental/pools',
data=json.dumps({
'name': name,
'slots': 1,
'description': '',
}),
content_type='application/json',
)
self.assertEqual(response.status_code, 400)
self.assertEqual(
json.loads(response.data.decode('utf-8'))['error'],
"Pool name shouldn't be empty",
)
self.assertEqual(self._get_pool_count(), 2)
def test_delete_pool(self):
response = self.app.delete(
'/api/experimental/pools/{}'.format(self.pool.pool),
)
self.assertEqual(response.status_code, 200)
self.assertEqual(json.loads(response.data.decode('utf-8')),
self.pool.to_json())
self.assertEqual(self._get_pool_count(), 1)
def test_delete_pool_non_existing(self):
response = self.app.delete(
'/api/experimental/pools/foo',
)
self.assertEqual(response.status_code, 404)
self.assertEqual(json.loads(response.data.decode('utf-8'))['error'],
"Pool 'foo' doesn't exist")
if __name__ == '__main__':
unittest.main()