[AIRFLOW-1275] Put 'airflow pool' into API

Closes #2346 from skudriashev/airflow-1275
This commit is contained in:
Stanislav Kudriashev 2017-06-21 16:36:45 +02:00 коммит произвёл Bolke de Bruin
Родитель a45e2d1888
Коммит 9958aa9d53
14 изменённых файлов: 673 добавлений и 98 удалений

Просмотреть файл

@ -14,17 +14,47 @@
#
class Client:
class Client(object):
"""Base API client for all API clients."""
def __init__(self, api_base_url, auth):
self._api_base_url = api_base_url
self._auth = auth
def trigger_dag(self, dag_id, run_id=None, conf=None, execution_date=None):
"""
Creates a dag run for the specified dag
"""Create a dag run for the specified dag.
:param dag_id:
:param run_id:
:param conf:
:param execution_date:
:return:
"""
raise NotImplementedError()
def get_pool(self, name):
"""Get pool.
:param name: pool name
"""
raise NotImplementedError()
def get_pools(self):
"""Get all pools."""
raise NotImplementedError()
def create_pool(self, name, slots, description):
"""Create a pool.
:param name: pool name
:param slots: pool slots amount
:param description: pool description
"""
raise NotImplementedError()
def delete_pool(self, name):
"""Delete pool.
:param name: pool name
"""
raise NotImplementedError()

Просмотреть файл

@ -11,30 +11,70 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from future.moves.urllib.parse import urljoin
import requests
from airflow.api.client import api_client
import requests
class Client(api_client.Client):
"""Json API client implementation."""
def _request(self, url, method='GET', json=None):
params = {
'url': url,
'auth': self._auth,
}
if json is not None:
params['json'] = json
resp = getattr(requests, method.lower())(**params)
if not resp.ok:
try:
data = resp.json()
except Exception:
data = {}
raise IOError(data.get('error', 'Server error'))
return resp.json()
def trigger_dag(self, dag_id, run_id=None, conf=None, execution_date=None):
endpoint = '/api/experimental/dags/{}/dag_runs'.format(dag_id)
url = urljoin(self._api_base_url, endpoint)
resp = requests.post(url,
auth=self._auth,
data = self._request(url, method='POST',
json={
"run_id": run_id,
"conf": conf,
"execution_date": execution_date,
})
if not resp.ok:
raise IOError()
data = resp.json()
return data['message']
def get_pool(self, name):
endpoint = '/api/experimental/pools/{}'.format(name)
url = urljoin(self._api_base_url, endpoint)
pool = self._request(url)
return pool['pool'], pool['slots'], pool['description']
def get_pools(self):
endpoint = '/api/experimental/pools'
url = urljoin(self._api_base_url, endpoint)
pools = self._request(url)
return [(p['pool'], p['slots'], p['description']) for p in pools]
def create_pool(self, name, slots, description):
endpoint = '/api/experimental/pools'
url = urljoin(self._api_base_url, endpoint)
pool = self._request(url, method='POST',
json={
'name': name,
'slots': slots,
'description': description,
})
return pool['pool'], pool['slots'], pool['description']
def delete_pool(self, name):
endpoint = '/api/experimental/pools/{}'.format(name)
url = urljoin(self._api_base_url, endpoint)
pool = self._request(url, method='DELETE')
return pool['pool'], pool['slots'], pool['description']

Просмотреть файл

@ -11,15 +11,33 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from airflow.api.client import api_client
from airflow.api.common.experimental import pool
from airflow.api.common.experimental import trigger_dag
class Client(api_client.Client):
"""Local API client implementation."""
def trigger_dag(self, dag_id, run_id=None, conf=None, execution_date=None):
dr = trigger_dag.trigger_dag(dag_id=dag_id,
run_id=run_id,
conf=conf,
execution_date=execution_date)
return "Created {}".format(dr)
def get_pool(self, name):
p = pool.get_pool(name=name)
return p.pool, p.slots, p.description
def get_pools(self):
return [(p.pool, p.slots, p.description) for p in pool.get_pools()]
def create_pool(self, name, slots, description):
p = pool.create_pool(name=name, slots=slots, description=description)
return p.pool, p.slots, p.description
def delete_pool(self, name):
p = pool.delete_pool(name=name)
return p.pool, p.slots, p.description

Просмотреть файл

@ -0,0 +1,85 @@
# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from airflow.exceptions import AirflowException
from airflow.models import Pool
from airflow.utils.db import provide_session
class PoolBadRequest(AirflowException):
status = 400
class PoolNotFound(AirflowException):
status = 404
@provide_session
def get_pool(name, session=None):
"""Get pool by a given name."""
if not (name and name.strip()):
raise PoolBadRequest("Pool name shouldn't be empty")
pool = session.query(Pool).filter_by(pool=name).first()
if pool is None:
raise PoolNotFound("Pool '%s' doesn't exist" % name)
return pool
@provide_session
def get_pools(session=None):
"""Get all pools."""
return session.query(Pool).all()
@provide_session
def create_pool(name, slots, description, session=None):
"""Create a pool with a given parameters."""
if not (name and name.strip()):
raise PoolBadRequest("Pool name shouldn't be empty")
try:
slots = int(slots)
except ValueError:
raise PoolBadRequest("Bad value for `slots`: %s" % slots)
session.expire_on_commit = False
pool = session.query(Pool).filter_by(pool=name).first()
if pool is None:
pool = Pool(pool=name, slots=slots, description=description)
session.add(pool)
else:
pool.slots = slots
pool.description = description
session.commit()
return pool
@provide_session
def delete_pool(name, session=None):
"""Delete pool by a given name."""
if not (name and name.strip()):
raise PoolBadRequest("Pool name shouldn't be empty")
pool = session.query(Pool).filter_by(pool=name).first()
if pool is None:
raise PoolNotFound("Pool '%s' doesn't exist" % name)
session.delete(pool)
session.commit()
return pool

Просмотреть файл

@ -49,7 +49,7 @@ from airflow.exceptions import AirflowException
from airflow.executors import GetDefaultExecutor
from airflow.models import (DagModel, DagBag, TaskInstance,
DagPickle, DagRun, Variable, DagStat,
Pool, Connection)
Connection)
from airflow.ti_deps.dep_context import (DepContext, SCHEDULER_DEPS)
from airflow.utils import db as db_utils
from airflow.utils import logging as logging_utils
@ -187,40 +187,28 @@ def trigger_dag(args):
def pool(args):
session = settings.Session()
if args.get or (args.set and args.set[0]) or args.delete:
name = args.get or args.delete or args.set[0]
pool = (
session.query(Pool)
.filter(Pool.pool == name)
.first())
if pool and args.get:
print("{} ".format(pool))
return
elif not pool and (args.get or args.delete):
print("No pool named {} found".format(name))
elif not pool and args.set:
pool = Pool(
pool=name,
slots=args.set[1],
description=args.set[2])
session.add(pool)
session.commit()
print("{} ".format(pool))
elif pool and args.set:
pool.slots = args.set[1]
pool.description = args.set[2]
session.commit()
print("{} ".format(pool))
return
elif pool and args.delete:
session.query(Pool).filter_by(pool=args.delete).delete()
session.commit()
print("Pool {} deleted".format(name))
def _tabulate(pools):
return "\n%s" % tabulate(pools, ['Pool', 'Slots', 'Description'],
tablefmt="fancy_grid")
try:
if args.get is not None:
pools = [api_client.get_pool(name=args.get)]
elif args.set:
pools = [api_client.create_pool(name=args.set[0],
slots=args.set[1],
description=args.set[2])]
elif args.delete:
pools = [api_client.delete_pool(name=args.delete)]
else:
pools = api_client.get_pools()
except (AirflowException, IOError) as err:
logging.error(err)
else:
logging.info(_tabulate(pools=pools))
def variables(args):
if args.get:
try:
var = Variable.get(args.get,

Просмотреть файл

@ -4395,6 +4395,14 @@ class Pool(Base):
def __repr__(self):
return self.pool
def to_json(self):
return {
'id': self.id,
'pool': self.pool,
'slots': self.slots,
'description': self.description,
}
@provide_session
def used_slots(self, session):
"""

Просмотреть файл

@ -11,10 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import airflow.api
from airflow.api.common.experimental import pool as pool_api
from airflow.api.common.experimental import trigger_dag as trigger
from airflow.api.common.experimental.get_task import get_task
from airflow.api.common.experimental.get_task_instance import get_task_instance
@ -96,7 +98,6 @@ def test():
@requires_authentication
def task_info(dag_id, task_id):
"""Returns a JSON with a task's public instance variables. """
try:
info = get_task(dag_id, task_id)
except AirflowException as err:
@ -169,4 +170,67 @@ def latest_dag_runs():
'dag_run_url': url_for('airflow.graph', dag_id=dagrun.dag_id,
execution_date=dagrun.execution_date)
})
return jsonify(items=payload) # old flask versions dont support jsonifying arrays
return jsonify(items=payload) # old flask versions dont support jsonifying arrays
@api_experimental.route('/pools/<string:name>', methods=['GET'])
@requires_authentication
def get_pool(name):
"""Get pool by a given name."""
try:
pool = pool_api.get_pool(name=name)
except AirflowException as e:
_log.error(e)
response = jsonify(error="{}".format(e))
response.status_code = getattr(e, 'status', 500)
return response
else:
return jsonify(pool.to_json())
@api_experimental.route('/pools', methods=['GET'])
@requires_authentication
def get_pools():
"""Get all pools."""
try:
pools = pool_api.get_pools()
except AirflowException as e:
_log.error(e)
response = jsonify(error="{}".format(e))
response.status_code = getattr(e, 'status', 500)
return response
else:
return jsonify([p.to_json() for p in pools])
@csrf.exempt
@api_experimental.route('/pools', methods=['POST'])
@requires_authentication
def create_pool():
"""Create a pool."""
params = request.get_json(force=True)
try:
pool = pool_api.create_pool(**params)
except AirflowException as e:
_log.error(e)
response = jsonify(error="{}".format(e))
response.status_code = getattr(e, 'status', 500)
return response
else:
return jsonify(pool.to_json())
@csrf.exempt
@api_experimental.route('/pools/<string:name>', methods=['DELETE'])
@requires_authentication
def delete_pool(name):
"""Delete pool."""
try:
pool = pool_api.delete_pool(name=name)
except AirflowException as e:
_log.error(e)
response = jsonify(error="{}".format(e))
response.status_code = getattr(e, 'status', 500)
return response
else:
return jsonify(pool.to_json())

Просмотреть файл

@ -11,9 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from .client import *
from .common import *

Просмотреть файл

@ -12,16 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
import json
import unittest
import datetime
from mock import patch
from airflow import AirflowException
from airflow import models
from airflow.api.client.local_client import Client
from airflow import models
from airflow import settings
from airflow.utils.state import State
EXECDATE = datetime.datetime.now()
@ -53,8 +53,25 @@ def mock_datetime_now(target, dt):
class TestLocalClient(unittest.TestCase):
@classmethod
def setUpClass(cls):
super(TestLocalClient, cls).setUpClass()
session = settings.Session()
session.query(models.Pool).delete()
session.commit()
session.close()
def setUp(self):
super(TestLocalClient, self).setUp()
self.client = Client(api_base_url=None, auth=None)
self.session = settings.Session()
def tearDown(self):
self.session.query(models.Pool).delete()
self.session.commit()
self.session.close()
super(TestLocalClient, self).tearDown()
@patch.object(models.DAG, 'create_dagrun')
def test_trigger_dag(self, mock):
@ -104,4 +121,24 @@ class TestLocalClient(unittest.TestCase):
external_trigger=True)
mock.reset_mock()
# this is a unit test only, cannot verify existing dag run
def test_get_pool(self):
self.client.create_pool(name='foo', slots=1, description='')
pool = self.client.get_pool(name='foo')
self.assertEqual(pool, ('foo', 1, ''))
def test_get_pools(self):
self.client.create_pool(name='foo1', slots=1, description='')
self.client.create_pool(name='foo2', slots=2, description='')
pools = sorted(self.client.get_pools(), key=lambda p: p[0])
self.assertEqual(pools, [('foo1', 1, ''), ('foo2', 2, '')])
def test_create_pool(self):
pool = self.client.create_pool(name='foo', slots=1, description='')
self.assertEqual(pool, ('foo', 1, ''))
self.assertEqual(self.session.query(models.Pool).count(), 1)
def test_delete_pool(self):
self.client.create_pool(name='foo', slots=1, description='')
self.assertEqual(self.session.query(models.Pool).count(), 1)
self.client.delete_pool(name='foo')
self.assertEqual(self.session.query(models.Pool).count(), 0)

Просмотреть файл

@ -0,0 +1,13 @@
# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Просмотреть файл

@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import unittest
@ -27,6 +26,7 @@ DEV_NULL = "/dev/null"
class TestMarkTasks(unittest.TestCase):
def setUp(self):
self.dagbag = models.DagBag(include_examples=True)
self.dag1 = self.dagbag.dags['test_example_bash_operator']
@ -52,6 +52,16 @@ class TestMarkTasks(unittest.TestCase):
self.session = Session()
def tearDown(self):
self.dag1.clear()
self.dag2.clear()
# just to make sure we are fully cleaned up
self.session.query(models.DagRun).delete()
self.session.query(models.TaskInstance).delete()
self.session.commit()
self.session.close()
def snapshot_state(self, dag, execution_dates):
TI = models.TaskInstance
tis = self.session.query(TI).filter(
@ -197,16 +207,6 @@ class TestMarkTasks(unittest.TestCase):
self.verify_state(self.dag2, task_ids, [self.execution_dates[0]],
State.SUCCESS, [])
def tearDown(self):
self.dag1.clear()
self.dag2.clear()
# just to make sure we are fully cleaned up
self.session.query(models.DagRun).delete()
self.session.query(models.TaskInstance).delete()
self.session.commit()
self.session.close()
class TestMarkDAGRun(unittest.TestCase):
def setUp(self):

Просмотреть файл

@ -0,0 +1,132 @@
# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from airflow.api.common.experimental import pool as pool_api
from airflow import models
from airflow import settings
class TestPool(unittest.TestCase):
def setUp(self):
super(TestPool, self).setUp()
self.session = settings.Session()
self.pools = []
for i in range(2):
name = 'experimental_%s' % (i + 1)
pool = models.Pool(
pool=name,
slots=i,
description=name,
)
self.session.add(pool)
self.pools.append(pool)
self.session.commit()
def tearDown(self):
self.session.query(models.Pool).delete()
self.session.commit()
self.session.close()
super(TestPool, self).tearDown()
def test_get_pool(self):
pool = pool_api.get_pool(name=self.pools[0].pool, session=self.session)
self.assertEqual(pool.pool, self.pools[0].pool)
def test_get_pool_non_existing(self):
self.assertRaisesRegexp(pool_api.PoolNotFound,
"^Pool 'test' doesn't exist$",
pool_api.get_pool,
name='test',
session=self.session)
def test_get_pool_bad_name(self):
for name in ('', ' '):
self.assertRaisesRegexp(pool_api.PoolBadRequest,
"^Pool name shouldn't be empty$",
pool_api.get_pool,
name=name,
session=self.session)
def test_get_pools(self):
pools = sorted(pool_api.get_pools(session=self.session),
key=lambda p: p.pool)
self.assertEqual(pools[0].pool, self.pools[0].pool)
self.assertEqual(pools[1].pool, self.pools[1].pool)
def test_create_pool(self):
pool = pool_api.create_pool(name='foo',
slots=5,
description='',
session=self.session)
self.assertEqual(pool.pool, 'foo')
self.assertEqual(pool.slots, 5)
self.assertEqual(pool.description, '')
self.assertEqual(self.session.query(models.Pool).count(), 3)
def test_create_pool_existing(self):
pool = pool_api.create_pool(name=self.pools[0].pool,
slots=5,
description='',
session=self.session)
self.assertEqual(pool.pool, self.pools[0].pool)
self.assertEqual(pool.slots, 5)
self.assertEqual(pool.description, '')
self.assertEqual(self.session.query(models.Pool).count(), 2)
def test_create_pool_bad_name(self):
for name in ('', ' '):
self.assertRaisesRegexp(pool_api.PoolBadRequest,
"^Pool name shouldn't be empty$",
pool_api.create_pool,
name=name,
slots=5,
description='',
session=self.session)
def test_create_pool_bad_slots(self):
self.assertRaisesRegexp(pool_api.PoolBadRequest,
"^Bad value for `slots`: foo$",
pool_api.create_pool,
name='foo',
slots='foo',
description='',
session=self.session)
def test_delete_pool(self):
pool = pool_api.delete_pool(name=self.pools[0].pool,
session=self.session)
self.assertEqual(pool.pool, self.pools[0].pool)
self.assertEqual(self.session.query(models.Pool).count(), 1)
def test_delete_pool_non_existing(self):
self.assertRaisesRegexp(pool_api.PoolNotFound,
"^Pool 'test' doesn't exist$",
pool_api.delete_pool,
name='test',
session=self.session)
def test_delete_pool_bad_name(self):
for name in ('', ' '):
self.assertRaisesRegexp(pool_api.PoolBadRequest,
"^Pool name shouldn't be empty$",
pool_api.delete_pool,
name=name,
session=self.session)
if __name__ == '__main__':
unittest.main()

Просмотреть файл

@ -1062,12 +1062,34 @@ class CoreTest(unittest.TestCase):
class CliTests(unittest.TestCase):
@classmethod
def setUpClass(cls):
super(CliTests, cls).setUpClass()
cls._cleanup()
def setUp(self):
super(CliTests, self).setUp()
configuration.load_test_config()
app = application.create_app()
app.config['TESTING'] = True
self.parser = cli.CLIFactory.get_parser()
self.dagbag = models.DagBag(dag_folder=DEV_NULL, include_examples=True)
self.session = Session()
def tearDown(self):
self._cleanup(session=self.session)
super(CliTests, self).tearDown()
@staticmethod
def _cleanup(session=None):
if session is None:
session = Session()
session.query(models.Pool).delete()
session.query(models.Variable).delete()
session.commit()
session.close()
def test_cli_list_dags(self):
args = self.parser.parse_args(['list_dags', '--report'])
@ -1100,8 +1122,8 @@ class CliTests(unittest.TestCase):
cli.connections(self.parser.parse_args(['connections', '--list']))
stdout = mock_stdout.getvalue()
conns = [[x.strip("'") for x in re.findall("'\w+'", line)[:2]]
for ii, line in enumerate(stdout.split('\n'))
if ii % 2 == 1]
for ii, line in enumerate(stdout.split('\n'))
if ii % 2 == 1]
conns = [conn for conn in conns if len(conn) > 0]
# Assert that some of the connections are present in the output as
@ -1365,14 +1387,27 @@ class CliTests(unittest.TestCase):
'-c', 'NOT JSON'])
)
def test_pool(self):
# Checks if all subcommands are properly received
cli.pool(self.parser.parse_args([
'pool', '-s', 'foo', '1', '"my foo pool"']))
cli.pool(self.parser.parse_args([
'pool', '-g', 'foo']))
cli.pool(self.parser.parse_args([
'pool', '-x', 'foo']))
def test_pool_create(self):
cli.pool(self.parser.parse_args(['pool', '-s', 'foo', '1', 'test']))
self.assertEqual(self.session.query(models.Pool).count(), 1)
def test_pool_get(self):
cli.pool(self.parser.parse_args(['pool', '-s', 'foo', '1', 'test']))
try:
cli.pool(self.parser.parse_args(['pool', '-g', 'foo']))
except Exception as e:
self.fail("The 'pool -g foo' command raised unexpectedly: %s" % e)
def test_pool_delete(self):
cli.pool(self.parser.parse_args(['pool', '-s', 'foo', '1', 'test']))
cli.pool(self.parser.parse_args(['pool', '-x', 'foo']))
self.assertEqual(self.session.query(models.Pool).count(), 0)
def test_pool_no_args(self):
try:
cli.pool(self.parser.parse_args(['pool']))
except Exception as e:
self.fail("The 'pool' command raised unexpectedly: %s" % e)
def test_variables(self):
# Checks if all subcommands are properly received
@ -1426,10 +1461,6 @@ class CliTests(unittest.TestCase):
self.assertEqual('original', models.Variable.get('bar'))
self.assertEqual('{"foo": "bar"}', models.Variable.get('foo'))
session = settings.Session()
session.query(Variable).delete()
session.commit()
session.close()
os.remove('variables1.json')
os.remove('variables2.json')

Просмотреть файл

@ -19,23 +19,36 @@ from urllib.parse import quote_plus
from airflow import configuration
from airflow.api.common.experimental.trigger_dag import trigger_dag
from airflow.models import DagBag, DagRun, TaskInstance
from airflow.models import DagBag, DagRun, Pool, TaskInstance
from airflow.settings import Session
from airflow.www import app as application
class ApiExperimentalTests(unittest.TestCase):
class TestApiExperimental(unittest.TestCase):
def setUp(self):
configuration.load_test_config()
app = application.create_app(testing=True)
self.app = app.test_client()
@classmethod
def setUpClass(cls):
super(TestApiExperimental, cls).setUpClass()
session = Session()
session.query(DagRun).delete()
session.query(TaskInstance).delete()
session.commit()
session.close()
def setUp(self):
super(TestApiExperimental, self).setUp()
configuration.load_test_config()
app = application.create_app(testing=True)
self.app = app.test_client()
def tearDown(self):
session = Session()
session.query(DagRun).delete()
session.query(TaskInstance).delete()
session.commit()
session.close()
super(TestApiExperimental, self).tearDown()
def test_task_info(self):
url_template = '/api/experimental/dags/{}/tasks/{}'
@ -62,7 +75,7 @@ class ApiExperimentalTests(unittest.TestCase):
url_template = '/api/experimental/dags/{}/dag_runs'
response = self.app.post(
url_template.format('example_bash_operator'),
data=json.dumps(dict(run_id='my_run' + datetime.now().isoformat())),
data=json.dumps({'run_id': 'my_run' + datetime.now().isoformat()}),
content_type="application/json"
)
@ -70,7 +83,7 @@ class ApiExperimentalTests(unittest.TestCase):
response = self.app.post(
url_template.format('does_not_exist_dag'),
data=json.dumps(dict()),
data=json.dumps({}),
content_type="application/json"
)
self.assertEqual(404, response.status_code)
@ -88,7 +101,7 @@ class ApiExperimentalTests(unittest.TestCase):
# Test Correct execution
response = self.app.post(
url_template.format(dag_id),
data=json.dumps(dict(execution_date=datetime_string)),
data=json.dumps({'execution_date': datetime_string}),
content_type="application/json"
)
self.assertEqual(200, response.status_code)
@ -103,7 +116,7 @@ class ApiExperimentalTests(unittest.TestCase):
# Test error for nonexistent dag
response = self.app.post(
url_template.format('does_not_exist_dag'),
data=json.dumps(dict(execution_date=execution_date.isoformat())),
data=json.dumps({'execution_date': execution_date.isoformat()}),
content_type="application/json"
)
self.assertEqual(404, response.status_code)
@ -111,7 +124,7 @@ class ApiExperimentalTests(unittest.TestCase):
# Test error for bad datetime format
response = self.app.post(
url_template.format(dag_id),
data=json.dumps(dict(execution_date='not_a_datetime')),
data=json.dumps({'execution_date': 'not_a_datetime'}),
content_type="application/json"
)
self.assertEqual(400, response.status_code)
@ -122,7 +135,9 @@ class ApiExperimentalTests(unittest.TestCase):
task_id = 'also_run_this'
execution_date = datetime.now().replace(microsecond=0)
datetime_string = quote_plus(execution_date.isoformat())
wrong_datetime_string = quote_plus(datetime(1990, 1, 1, 1, 1, 1).isoformat())
wrong_datetime_string = quote_plus(
datetime(1990, 1, 1, 1, 1, 1).isoformat()
)
# Create DagRun
trigger_dag(dag_id=dag_id,
@ -139,7 +154,8 @@ class ApiExperimentalTests(unittest.TestCase):
# Test error for nonexistent dag
response = self.app.get(
url_template.format('does_not_exist_dag', datetime_string, task_id),
url_template.format('does_not_exist_dag', datetime_string,
task_id),
)
self.assertEqual(404, response.status_code)
self.assertIn('error', response.data.decode('utf-8'))
@ -164,3 +180,122 @@ class ApiExperimentalTests(unittest.TestCase):
)
self.assertEqual(400, response.status_code)
self.assertIn('error', response.data.decode('utf-8'))
class TestPoolApiExperimental(unittest.TestCase):
@classmethod
def setUpClass(cls):
super(TestPoolApiExperimental, cls).setUpClass()
session = Session()
session.query(Pool).delete()
session.commit()
session.close()
def setUp(self):
super(TestPoolApiExperimental, self).setUp()
configuration.load_test_config()
app = application.create_app(testing=True)
self.app = app.test_client()
self.session = Session()
self.pools = []
for i in range(2):
name = 'experimental_%s' % (i + 1)
pool = Pool(
pool=name,
slots=i,
description=name,
)
self.session.add(pool)
self.pools.append(pool)
self.session.commit()
self.pool = self.pools[0]
def tearDown(self):
self.session.query(Pool).delete()
self.session.commit()
self.session.close()
super(TestPoolApiExperimental, self).tearDown()
def _get_pool_count(self):
response = self.app.get('/api/experimental/pools')
self.assertEqual(response.status_code, 200)
return len(json.loads(response.data.decode('utf-8')))
def test_get_pool(self):
response = self.app.get(
'/api/experimental/pools/{}'.format(self.pool.pool),
)
self.assertEqual(response.status_code, 200)
self.assertEqual(json.loads(response.data.decode('utf-8')),
self.pool.to_json())
def test_get_pool_non_existing(self):
response = self.app.get('/api/experimental/pools/foo')
self.assertEqual(response.status_code, 404)
self.assertEqual(json.loads(response.data.decode('utf-8'))['error'],
"Pool 'foo' doesn't exist")
def test_get_pools(self):
response = self.app.get('/api/experimental/pools')
self.assertEqual(response.status_code, 200)
pools = json.loads(response.data.decode('utf-8'))
self.assertEqual(len(pools), 2)
for i, pool in enumerate(sorted(pools, key=lambda p: p['pool'])):
self.assertDictEqual(pool, self.pools[i].to_json())
def test_create_pool(self):
response = self.app.post(
'/api/experimental/pools',
data=json.dumps({
'name': 'foo',
'slots': 1,
'description': '',
}),
content_type='application/json',
)
self.assertEqual(response.status_code, 200)
pool = json.loads(response.data.decode('utf-8'))
self.assertEqual(pool['pool'], 'foo')
self.assertEqual(pool['slots'], 1)
self.assertEqual(pool['description'], '')
self.assertEqual(self._get_pool_count(), 3)
def test_create_pool_with_bad_name(self):
for name in ('', ' '):
response = self.app.post(
'/api/experimental/pools',
data=json.dumps({
'name': name,
'slots': 1,
'description': '',
}),
content_type='application/json',
)
self.assertEqual(response.status_code, 400)
self.assertEqual(
json.loads(response.data.decode('utf-8'))['error'],
"Pool name shouldn't be empty",
)
self.assertEqual(self._get_pool_count(), 2)
def test_delete_pool(self):
response = self.app.delete(
'/api/experimental/pools/{}'.format(self.pool.pool),
)
self.assertEqual(response.status_code, 200)
self.assertEqual(json.loads(response.data.decode('utf-8')),
self.pool.to_json())
self.assertEqual(self._get_pool_count(), 1)
def test_delete_pool_non_existing(self):
response = self.app.delete(
'/api/experimental/pools/foo',
)
self.assertEqual(response.status_code, 404)
self.assertEqual(json.loads(response.data.decode('utf-8'))['error'],
"Pool 'foo' doesn't exist")
if __name__ == '__main__':
unittest.main()