diff --git a/airflow/api/client/api_client.py b/airflow/api/client/api_client.py index 6a775384ff..f24d80945f 100644 --- a/airflow/api/client/api_client.py +++ b/airflow/api/client/api_client.py @@ -14,17 +14,47 @@ # -class Client: +class Client(object): + """Base API client for all API clients.""" + def __init__(self, api_base_url, auth): self._api_base_url = api_base_url self._auth = auth def trigger_dag(self, dag_id, run_id=None, conf=None, execution_date=None): - """ - Creates a dag run for the specified dag + """Create a dag run for the specified dag. + :param dag_id: :param run_id: :param conf: + :param execution_date: :return: """ raise NotImplementedError() + + def get_pool(self, name): + """Get pool. + + :param name: pool name + """ + raise NotImplementedError() + + def get_pools(self): + """Get all pools.""" + raise NotImplementedError() + + def create_pool(self, name, slots, description): + """Create a pool. + + :param name: pool name + :param slots: pool slots amount + :param description: pool description + """ + raise NotImplementedError() + + def delete_pool(self, name): + """Delete pool. + + :param name: pool name + """ + raise NotImplementedError() diff --git a/airflow/api/client/json_client.py b/airflow/api/client/json_client.py index d74fc636cd..37e24d3c4e 100644 --- a/airflow/api/client/json_client.py +++ b/airflow/api/client/json_client.py @@ -11,30 +11,70 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# + from future.moves.urllib.parse import urljoin +import requests from airflow.api.client import api_client -import requests - class Client(api_client.Client): + """Json API client implementation.""" + + def _request(self, url, method='GET', json=None): + params = { + 'url': url, + 'auth': self._auth, + } + if json is not None: + params['json'] = json + + resp = getattr(requests, method.lower())(**params) + if not resp.ok: + try: + data = resp.json() + except Exception: + data = {} + raise IOError(data.get('error', 'Server error')) + + return resp.json() + def trigger_dag(self, dag_id, run_id=None, conf=None, execution_date=None): endpoint = '/api/experimental/dags/{}/dag_runs'.format(dag_id) url = urljoin(self._api_base_url, endpoint) - - resp = requests.post(url, - auth=self._auth, + data = self._request(url, method='POST', json={ "run_id": run_id, "conf": conf, "execution_date": execution_date, }) - - if not resp.ok: - raise IOError() - - data = resp.json() - return data['message'] + + def get_pool(self, name): + endpoint = '/api/experimental/pools/{}'.format(name) + url = urljoin(self._api_base_url, endpoint) + pool = self._request(url) + return pool['pool'], pool['slots'], pool['description'] + + def get_pools(self): + endpoint = '/api/experimental/pools' + url = urljoin(self._api_base_url, endpoint) + pools = self._request(url) + return [(p['pool'], p['slots'], p['description']) for p in pools] + + def create_pool(self, name, slots, description): + endpoint = '/api/experimental/pools' + url = urljoin(self._api_base_url, endpoint) + pool = self._request(url, method='POST', + json={ + 'name': name, + 'slots': slots, + 'description': description, + }) + return pool['pool'], pool['slots'], pool['description'] + + def delete_pool(self, name): + endpoint = '/api/experimental/pools/{}'.format(name) + url = urljoin(self._api_base_url, endpoint) + pool = self._request(url, method='DELETE') + return pool['pool'], pool['slots'], pool['description'] diff --git a/airflow/api/client/local_client.py b/airflow/api/client/local_client.py index 05f27f6914..5bc7f76aaa 100644 --- a/airflow/api/client/local_client.py +++ b/airflow/api/client/local_client.py @@ -11,15 +11,33 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# + from airflow.api.client import api_client +from airflow.api.common.experimental import pool from airflow.api.common.experimental import trigger_dag class Client(api_client.Client): + """Local API client implementation.""" + def trigger_dag(self, dag_id, run_id=None, conf=None, execution_date=None): dr = trigger_dag.trigger_dag(dag_id=dag_id, run_id=run_id, conf=conf, execution_date=execution_date) return "Created {}".format(dr) + + def get_pool(self, name): + p = pool.get_pool(name=name) + return p.pool, p.slots, p.description + + def get_pools(self): + return [(p.pool, p.slots, p.description) for p in pool.get_pools()] + + def create_pool(self, name, slots, description): + p = pool.create_pool(name=name, slots=slots, description=description) + return p.pool, p.slots, p.description + + def delete_pool(self, name): + p = pool.delete_pool(name=name) + return p.pool, p.slots, p.description diff --git a/airflow/api/common/experimental/pool.py b/airflow/api/common/experimental/pool.py new file mode 100644 index 0000000000..6e963a2fd2 --- /dev/null +++ b/airflow/api/common/experimental/pool.py @@ -0,0 +1,85 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from airflow.exceptions import AirflowException +from airflow.models import Pool +from airflow.utils.db import provide_session + + +class PoolBadRequest(AirflowException): + status = 400 + + +class PoolNotFound(AirflowException): + status = 404 + + +@provide_session +def get_pool(name, session=None): + """Get pool by a given name.""" + if not (name and name.strip()): + raise PoolBadRequest("Pool name shouldn't be empty") + + pool = session.query(Pool).filter_by(pool=name).first() + if pool is None: + raise PoolNotFound("Pool '%s' doesn't exist" % name) + + return pool + + +@provide_session +def get_pools(session=None): + """Get all pools.""" + return session.query(Pool).all() + + +@provide_session +def create_pool(name, slots, description, session=None): + """Create a pool with a given parameters.""" + if not (name and name.strip()): + raise PoolBadRequest("Pool name shouldn't be empty") + + try: + slots = int(slots) + except ValueError: + raise PoolBadRequest("Bad value for `slots`: %s" % slots) + + session.expire_on_commit = False + pool = session.query(Pool).filter_by(pool=name).first() + if pool is None: + pool = Pool(pool=name, slots=slots, description=description) + session.add(pool) + else: + pool.slots = slots + pool.description = description + + session.commit() + + return pool + + +@provide_session +def delete_pool(name, session=None): + """Delete pool by a given name.""" + if not (name and name.strip()): + raise PoolBadRequest("Pool name shouldn't be empty") + + pool = session.query(Pool).filter_by(pool=name).first() + if pool is None: + raise PoolNotFound("Pool '%s' doesn't exist" % name) + + session.delete(pool) + session.commit() + + return pool diff --git a/airflow/bin/cli.py b/airflow/bin/cli.py index 41f979fa51..4b3a0edc8f 100755 --- a/airflow/bin/cli.py +++ b/airflow/bin/cli.py @@ -49,7 +49,7 @@ from airflow.exceptions import AirflowException from airflow.executors import GetDefaultExecutor from airflow.models import (DagModel, DagBag, TaskInstance, DagPickle, DagRun, Variable, DagStat, - Pool, Connection) + Connection) from airflow.ti_deps.dep_context import (DepContext, SCHEDULER_DEPS) from airflow.utils import db as db_utils from airflow.utils import logging as logging_utils @@ -187,40 +187,28 @@ def trigger_dag(args): def pool(args): - session = settings.Session() - if args.get or (args.set and args.set[0]) or args.delete: - name = args.get or args.delete or args.set[0] - pool = ( - session.query(Pool) - .filter(Pool.pool == name) - .first()) - if pool and args.get: - print("{} ".format(pool)) - return - elif not pool and (args.get or args.delete): - print("No pool named {} found".format(name)) - elif not pool and args.set: - pool = Pool( - pool=name, - slots=args.set[1], - description=args.set[2]) - session.add(pool) - session.commit() - print("{} ".format(pool)) - elif pool and args.set: - pool.slots = args.set[1] - pool.description = args.set[2] - session.commit() - print("{} ".format(pool)) - return - elif pool and args.delete: - session.query(Pool).filter_by(pool=args.delete).delete() - session.commit() - print("Pool {} deleted".format(name)) + def _tabulate(pools): + return "\n%s" % tabulate(pools, ['Pool', 'Slots', 'Description'], + tablefmt="fancy_grid") + + try: + if args.get is not None: + pools = [api_client.get_pool(name=args.get)] + elif args.set: + pools = [api_client.create_pool(name=args.set[0], + slots=args.set[1], + description=args.set[2])] + elif args.delete: + pools = [api_client.delete_pool(name=args.delete)] + else: + pools = api_client.get_pools() + except (AirflowException, IOError) as err: + logging.error(err) + else: + logging.info(_tabulate(pools=pools)) def variables(args): - if args.get: try: var = Variable.get(args.get, diff --git a/airflow/models.py b/airflow/models.py index 2c433ad15d..000257218c 100755 --- a/airflow/models.py +++ b/airflow/models.py @@ -4395,6 +4395,14 @@ class Pool(Base): def __repr__(self): return self.pool + def to_json(self): + return { + 'id': self.id, + 'pool': self.pool, + 'slots': self.slots, + 'description': self.description, + } + @provide_session def used_slots(self, session): """ diff --git a/airflow/www/api/experimental/endpoints.py b/airflow/www/api/experimental/endpoints.py index be9273574d..a8d7f5c5aa 100644 --- a/airflow/www/api/experimental/endpoints.py +++ b/airflow/www/api/experimental/endpoints.py @@ -11,10 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import logging import airflow.api +from airflow.api.common.experimental import pool as pool_api from airflow.api.common.experimental import trigger_dag as trigger from airflow.api.common.experimental.get_task import get_task from airflow.api.common.experimental.get_task_instance import get_task_instance @@ -96,7 +98,6 @@ def test(): @requires_authentication def task_info(dag_id, task_id): """Returns a JSON with a task's public instance variables. """ - try: info = get_task(dag_id, task_id) except AirflowException as err: @@ -169,4 +170,67 @@ def latest_dag_runs(): 'dag_run_url': url_for('airflow.graph', dag_id=dagrun.dag_id, execution_date=dagrun.execution_date) }) - return jsonify(items=payload) # old flask versions dont support jsonifying arrays + return jsonify(items=payload) # old flask versions dont support jsonifying arrays + + +@api_experimental.route('/pools/', methods=['GET']) +@requires_authentication +def get_pool(name): + """Get pool by a given name.""" + try: + pool = pool_api.get_pool(name=name) + except AirflowException as e: + _log.error(e) + response = jsonify(error="{}".format(e)) + response.status_code = getattr(e, 'status', 500) + return response + else: + return jsonify(pool.to_json()) + + +@api_experimental.route('/pools', methods=['GET']) +@requires_authentication +def get_pools(): + """Get all pools.""" + try: + pools = pool_api.get_pools() + except AirflowException as e: + _log.error(e) + response = jsonify(error="{}".format(e)) + response.status_code = getattr(e, 'status', 500) + return response + else: + return jsonify([p.to_json() for p in pools]) + + +@csrf.exempt +@api_experimental.route('/pools', methods=['POST']) +@requires_authentication +def create_pool(): + """Create a pool.""" + params = request.get_json(force=True) + try: + pool = pool_api.create_pool(**params) + except AirflowException as e: + _log.error(e) + response = jsonify(error="{}".format(e)) + response.status_code = getattr(e, 'status', 500) + return response + else: + return jsonify(pool.to_json()) + + +@csrf.exempt +@api_experimental.route('/pools/', methods=['DELETE']) +@requires_authentication +def delete_pool(name): + """Delete pool.""" + try: + pool = pool_api.delete_pool(name=name) + except AirflowException as e: + _log.error(e) + response = jsonify(error="{}".format(e)) + response.status_code = getattr(e, 'status', 500) + return response + else: + return jsonify(pool.to_json()) diff --git a/tests/api/__init__.py b/tests/api/__init__.py index 37d59f0d34..9d7677a99b 100644 --- a/tests/api/__init__.py +++ b/tests/api/__init__.py @@ -11,9 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from __future__ import absolute_import - -from .client import * -from .common import * - diff --git a/tests/api/client/local_client.py b/tests/api/client/test_local_client.py similarity index 72% rename from tests/api/client/local_client.py rename to tests/api/client/test_local_client.py index a36b71f01f..7a759fe6a5 100644 --- a/tests/api/client/local_client.py +++ b/tests/api/client/test_local_client.py @@ -12,16 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +import datetime import json import unittest -import datetime from mock import patch from airflow import AirflowException -from airflow import models - from airflow.api.client.local_client import Client +from airflow import models +from airflow import settings from airflow.utils.state import State EXECDATE = datetime.datetime.now() @@ -53,8 +53,25 @@ def mock_datetime_now(target, dt): class TestLocalClient(unittest.TestCase): + + @classmethod + def setUpClass(cls): + super(TestLocalClient, cls).setUpClass() + session = settings.Session() + session.query(models.Pool).delete() + session.commit() + session.close() + def setUp(self): + super(TestLocalClient, self).setUp() self.client = Client(api_base_url=None, auth=None) + self.session = settings.Session() + + def tearDown(self): + self.session.query(models.Pool).delete() + self.session.commit() + self.session.close() + super(TestLocalClient, self).tearDown() @patch.object(models.DAG, 'create_dagrun') def test_trigger_dag(self, mock): @@ -104,4 +121,24 @@ class TestLocalClient(unittest.TestCase): external_trigger=True) mock.reset_mock() - # this is a unit test only, cannot verify existing dag run + def test_get_pool(self): + self.client.create_pool(name='foo', slots=1, description='') + pool = self.client.get_pool(name='foo') + self.assertEqual(pool, ('foo', 1, '')) + + def test_get_pools(self): + self.client.create_pool(name='foo1', slots=1, description='') + self.client.create_pool(name='foo2', slots=2, description='') + pools = sorted(self.client.get_pools(), key=lambda p: p[0]) + self.assertEqual(pools, [('foo1', 1, ''), ('foo2', 2, '')]) + + def test_create_pool(self): + pool = self.client.create_pool(name='foo', slots=1, description='') + self.assertEqual(pool, ('foo', 1, '')) + self.assertEqual(self.session.query(models.Pool).count(), 1) + + def test_delete_pool(self): + self.client.create_pool(name='foo', slots=1, description='') + self.assertEqual(self.session.query(models.Pool).count(), 1) + self.client.delete_pool(name='foo') + self.assertEqual(self.session.query(models.Pool).count(), 0) diff --git a/tests/api/common/experimental/__init__.py b/tests/api/common/experimental/__init__.py new file mode 100644 index 0000000000..9d7677a99b --- /dev/null +++ b/tests/api/common/experimental/__init__.py @@ -0,0 +1,13 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/api/common/mark_tasks.py b/tests/api/common/experimental/mark_tasks.py similarity index 99% rename from tests/api/common/mark_tasks.py rename to tests/api/common/experimental/mark_tasks.py index 8a3759f8ba..e4395aeaa8 100644 --- a/tests/api/common/mark_tasks.py +++ b/tests/api/common/experimental/mark_tasks.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# import unittest @@ -27,6 +26,7 @@ DEV_NULL = "/dev/null" class TestMarkTasks(unittest.TestCase): + def setUp(self): self.dagbag = models.DagBag(include_examples=True) self.dag1 = self.dagbag.dags['test_example_bash_operator'] @@ -52,6 +52,16 @@ class TestMarkTasks(unittest.TestCase): self.session = Session() + def tearDown(self): + self.dag1.clear() + self.dag2.clear() + + # just to make sure we are fully cleaned up + self.session.query(models.DagRun).delete() + self.session.query(models.TaskInstance).delete() + self.session.commit() + self.session.close() + def snapshot_state(self, dag, execution_dates): TI = models.TaskInstance tis = self.session.query(TI).filter( @@ -197,16 +207,6 @@ class TestMarkTasks(unittest.TestCase): self.verify_state(self.dag2, task_ids, [self.execution_dates[0]], State.SUCCESS, []) - def tearDown(self): - self.dag1.clear() - self.dag2.clear() - - # just to make sure we are fully cleaned up - self.session.query(models.DagRun).delete() - self.session.query(models.TaskInstance).delete() - self.session.commit() - - self.session.close() class TestMarkDAGRun(unittest.TestCase): def setUp(self): diff --git a/tests/api/common/experimental/test_pool.py b/tests/api/common/experimental/test_pool.py new file mode 100644 index 0000000000..98969b8ce5 --- /dev/null +++ b/tests/api/common/experimental/test_pool.py @@ -0,0 +1,132 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from airflow.api.common.experimental import pool as pool_api +from airflow import models +from airflow import settings + + +class TestPool(unittest.TestCase): + + def setUp(self): + super(TestPool, self).setUp() + self.session = settings.Session() + self.pools = [] + for i in range(2): + name = 'experimental_%s' % (i + 1) + pool = models.Pool( + pool=name, + slots=i, + description=name, + ) + self.session.add(pool) + self.pools.append(pool) + self.session.commit() + + def tearDown(self): + self.session.query(models.Pool).delete() + self.session.commit() + self.session.close() + super(TestPool, self).tearDown() + + def test_get_pool(self): + pool = pool_api.get_pool(name=self.pools[0].pool, session=self.session) + self.assertEqual(pool.pool, self.pools[0].pool) + + def test_get_pool_non_existing(self): + self.assertRaisesRegexp(pool_api.PoolNotFound, + "^Pool 'test' doesn't exist$", + pool_api.get_pool, + name='test', + session=self.session) + + def test_get_pool_bad_name(self): + for name in ('', ' '): + self.assertRaisesRegexp(pool_api.PoolBadRequest, + "^Pool name shouldn't be empty$", + pool_api.get_pool, + name=name, + session=self.session) + + def test_get_pools(self): + pools = sorted(pool_api.get_pools(session=self.session), + key=lambda p: p.pool) + self.assertEqual(pools[0].pool, self.pools[0].pool) + self.assertEqual(pools[1].pool, self.pools[1].pool) + + def test_create_pool(self): + pool = pool_api.create_pool(name='foo', + slots=5, + description='', + session=self.session) + self.assertEqual(pool.pool, 'foo') + self.assertEqual(pool.slots, 5) + self.assertEqual(pool.description, '') + self.assertEqual(self.session.query(models.Pool).count(), 3) + + def test_create_pool_existing(self): + pool = pool_api.create_pool(name=self.pools[0].pool, + slots=5, + description='', + session=self.session) + self.assertEqual(pool.pool, self.pools[0].pool) + self.assertEqual(pool.slots, 5) + self.assertEqual(pool.description, '') + self.assertEqual(self.session.query(models.Pool).count(), 2) + + def test_create_pool_bad_name(self): + for name in ('', ' '): + self.assertRaisesRegexp(pool_api.PoolBadRequest, + "^Pool name shouldn't be empty$", + pool_api.create_pool, + name=name, + slots=5, + description='', + session=self.session) + + def test_create_pool_bad_slots(self): + self.assertRaisesRegexp(pool_api.PoolBadRequest, + "^Bad value for `slots`: foo$", + pool_api.create_pool, + name='foo', + slots='foo', + description='', + session=self.session) + + def test_delete_pool(self): + pool = pool_api.delete_pool(name=self.pools[0].pool, + session=self.session) + self.assertEqual(pool.pool, self.pools[0].pool) + self.assertEqual(self.session.query(models.Pool).count(), 1) + + def test_delete_pool_non_existing(self): + self.assertRaisesRegexp(pool_api.PoolNotFound, + "^Pool 'test' doesn't exist$", + pool_api.delete_pool, + name='test', + session=self.session) + + def test_delete_pool_bad_name(self): + for name in ('', ' '): + self.assertRaisesRegexp(pool_api.PoolBadRequest, + "^Pool name shouldn't be empty$", + pool_api.delete_pool, + name=name, + session=self.session) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/core.py b/tests/core.py index 8ccd4e71ff..259b61da95 100644 --- a/tests/core.py +++ b/tests/core.py @@ -1062,12 +1062,34 @@ class CoreTest(unittest.TestCase): class CliTests(unittest.TestCase): + + @classmethod + def setUpClass(cls): + super(CliTests, cls).setUpClass() + cls._cleanup() + def setUp(self): + super(CliTests, self).setUp() configuration.load_test_config() app = application.create_app() app.config['TESTING'] = True self.parser = cli.CLIFactory.get_parser() self.dagbag = models.DagBag(dag_folder=DEV_NULL, include_examples=True) + self.session = Session() + + def tearDown(self): + self._cleanup(session=self.session) + super(CliTests, self).tearDown() + + @staticmethod + def _cleanup(session=None): + if session is None: + session = Session() + + session.query(models.Pool).delete() + session.query(models.Variable).delete() + session.commit() + session.close() def test_cli_list_dags(self): args = self.parser.parse_args(['list_dags', '--report']) @@ -1100,8 +1122,8 @@ class CliTests(unittest.TestCase): cli.connections(self.parser.parse_args(['connections', '--list'])) stdout = mock_stdout.getvalue() conns = [[x.strip("'") for x in re.findall("'\w+'", line)[:2]] - for ii, line in enumerate(stdout.split('\n')) - if ii % 2 == 1] + for ii, line in enumerate(stdout.split('\n')) + if ii % 2 == 1] conns = [conn for conn in conns if len(conn) > 0] # Assert that some of the connections are present in the output as @@ -1365,14 +1387,27 @@ class CliTests(unittest.TestCase): '-c', 'NOT JSON']) ) - def test_pool(self): - # Checks if all subcommands are properly received - cli.pool(self.parser.parse_args([ - 'pool', '-s', 'foo', '1', '"my foo pool"'])) - cli.pool(self.parser.parse_args([ - 'pool', '-g', 'foo'])) - cli.pool(self.parser.parse_args([ - 'pool', '-x', 'foo'])) + def test_pool_create(self): + cli.pool(self.parser.parse_args(['pool', '-s', 'foo', '1', 'test'])) + self.assertEqual(self.session.query(models.Pool).count(), 1) + + def test_pool_get(self): + cli.pool(self.parser.parse_args(['pool', '-s', 'foo', '1', 'test'])) + try: + cli.pool(self.parser.parse_args(['pool', '-g', 'foo'])) + except Exception as e: + self.fail("The 'pool -g foo' command raised unexpectedly: %s" % e) + + def test_pool_delete(self): + cli.pool(self.parser.parse_args(['pool', '-s', 'foo', '1', 'test'])) + cli.pool(self.parser.parse_args(['pool', '-x', 'foo'])) + self.assertEqual(self.session.query(models.Pool).count(), 0) + + def test_pool_no_args(self): + try: + cli.pool(self.parser.parse_args(['pool'])) + except Exception as e: + self.fail("The 'pool' command raised unexpectedly: %s" % e) def test_variables(self): # Checks if all subcommands are properly received @@ -1426,10 +1461,6 @@ class CliTests(unittest.TestCase): self.assertEqual('original', models.Variable.get('bar')) self.assertEqual('{"foo": "bar"}', models.Variable.get('foo')) - session = settings.Session() - session.query(Variable).delete() - session.commit() - session.close() os.remove('variables1.json') os.remove('variables2.json') diff --git a/tests/www/api/experimental/test_endpoints.py b/tests/www/api/experimental/test_endpoints.py index dacee321ec..65a6f75864 100644 --- a/tests/www/api/experimental/test_endpoints.py +++ b/tests/www/api/experimental/test_endpoints.py @@ -19,23 +19,36 @@ from urllib.parse import quote_plus from airflow import configuration from airflow.api.common.experimental.trigger_dag import trigger_dag -from airflow.models import DagBag, DagRun, TaskInstance +from airflow.models import DagBag, DagRun, Pool, TaskInstance from airflow.settings import Session from airflow.www import app as application -class ApiExperimentalTests(unittest.TestCase): +class TestApiExperimental(unittest.TestCase): - def setUp(self): - configuration.load_test_config() - app = application.create_app(testing=True) - self.app = app.test_client() + @classmethod + def setUpClass(cls): + super(TestApiExperimental, cls).setUpClass() session = Session() session.query(DagRun).delete() session.query(TaskInstance).delete() session.commit() session.close() + def setUp(self): + super(TestApiExperimental, self).setUp() + configuration.load_test_config() + app = application.create_app(testing=True) + self.app = app.test_client() + + def tearDown(self): + session = Session() + session.query(DagRun).delete() + session.query(TaskInstance).delete() + session.commit() + session.close() + super(TestApiExperimental, self).tearDown() + def test_task_info(self): url_template = '/api/experimental/dags/{}/tasks/{}' @@ -62,7 +75,7 @@ class ApiExperimentalTests(unittest.TestCase): url_template = '/api/experimental/dags/{}/dag_runs' response = self.app.post( url_template.format('example_bash_operator'), - data=json.dumps(dict(run_id='my_run' + datetime.now().isoformat())), + data=json.dumps({'run_id': 'my_run' + datetime.now().isoformat()}), content_type="application/json" ) @@ -70,7 +83,7 @@ class ApiExperimentalTests(unittest.TestCase): response = self.app.post( url_template.format('does_not_exist_dag'), - data=json.dumps(dict()), + data=json.dumps({}), content_type="application/json" ) self.assertEqual(404, response.status_code) @@ -88,7 +101,7 @@ class ApiExperimentalTests(unittest.TestCase): # Test Correct execution response = self.app.post( url_template.format(dag_id), - data=json.dumps(dict(execution_date=datetime_string)), + data=json.dumps({'execution_date': datetime_string}), content_type="application/json" ) self.assertEqual(200, response.status_code) @@ -103,7 +116,7 @@ class ApiExperimentalTests(unittest.TestCase): # Test error for nonexistent dag response = self.app.post( url_template.format('does_not_exist_dag'), - data=json.dumps(dict(execution_date=execution_date.isoformat())), + data=json.dumps({'execution_date': execution_date.isoformat()}), content_type="application/json" ) self.assertEqual(404, response.status_code) @@ -111,7 +124,7 @@ class ApiExperimentalTests(unittest.TestCase): # Test error for bad datetime format response = self.app.post( url_template.format(dag_id), - data=json.dumps(dict(execution_date='not_a_datetime')), + data=json.dumps({'execution_date': 'not_a_datetime'}), content_type="application/json" ) self.assertEqual(400, response.status_code) @@ -122,7 +135,9 @@ class ApiExperimentalTests(unittest.TestCase): task_id = 'also_run_this' execution_date = datetime.now().replace(microsecond=0) datetime_string = quote_plus(execution_date.isoformat()) - wrong_datetime_string = quote_plus(datetime(1990, 1, 1, 1, 1, 1).isoformat()) + wrong_datetime_string = quote_plus( + datetime(1990, 1, 1, 1, 1, 1).isoformat() + ) # Create DagRun trigger_dag(dag_id=dag_id, @@ -139,7 +154,8 @@ class ApiExperimentalTests(unittest.TestCase): # Test error for nonexistent dag response = self.app.get( - url_template.format('does_not_exist_dag', datetime_string, task_id), + url_template.format('does_not_exist_dag', datetime_string, + task_id), ) self.assertEqual(404, response.status_code) self.assertIn('error', response.data.decode('utf-8')) @@ -164,3 +180,122 @@ class ApiExperimentalTests(unittest.TestCase): ) self.assertEqual(400, response.status_code) self.assertIn('error', response.data.decode('utf-8')) + + +class TestPoolApiExperimental(unittest.TestCase): + + @classmethod + def setUpClass(cls): + super(TestPoolApiExperimental, cls).setUpClass() + session = Session() + session.query(Pool).delete() + session.commit() + session.close() + + def setUp(self): + super(TestPoolApiExperimental, self).setUp() + configuration.load_test_config() + app = application.create_app(testing=True) + self.app = app.test_client() + self.session = Session() + self.pools = [] + for i in range(2): + name = 'experimental_%s' % (i + 1) + pool = Pool( + pool=name, + slots=i, + description=name, + ) + self.session.add(pool) + self.pools.append(pool) + self.session.commit() + self.pool = self.pools[0] + + def tearDown(self): + self.session.query(Pool).delete() + self.session.commit() + self.session.close() + super(TestPoolApiExperimental, self).tearDown() + + def _get_pool_count(self): + response = self.app.get('/api/experimental/pools') + self.assertEqual(response.status_code, 200) + return len(json.loads(response.data.decode('utf-8'))) + + def test_get_pool(self): + response = self.app.get( + '/api/experimental/pools/{}'.format(self.pool.pool), + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(json.loads(response.data.decode('utf-8')), + self.pool.to_json()) + + def test_get_pool_non_existing(self): + response = self.app.get('/api/experimental/pools/foo') + self.assertEqual(response.status_code, 404) + self.assertEqual(json.loads(response.data.decode('utf-8'))['error'], + "Pool 'foo' doesn't exist") + + def test_get_pools(self): + response = self.app.get('/api/experimental/pools') + self.assertEqual(response.status_code, 200) + pools = json.loads(response.data.decode('utf-8')) + self.assertEqual(len(pools), 2) + for i, pool in enumerate(sorted(pools, key=lambda p: p['pool'])): + self.assertDictEqual(pool, self.pools[i].to_json()) + + def test_create_pool(self): + response = self.app.post( + '/api/experimental/pools', + data=json.dumps({ + 'name': 'foo', + 'slots': 1, + 'description': '', + }), + content_type='application/json', + ) + self.assertEqual(response.status_code, 200) + pool = json.loads(response.data.decode('utf-8')) + self.assertEqual(pool['pool'], 'foo') + self.assertEqual(pool['slots'], 1) + self.assertEqual(pool['description'], '') + self.assertEqual(self._get_pool_count(), 3) + + def test_create_pool_with_bad_name(self): + for name in ('', ' '): + response = self.app.post( + '/api/experimental/pools', + data=json.dumps({ + 'name': name, + 'slots': 1, + 'description': '', + }), + content_type='application/json', + ) + self.assertEqual(response.status_code, 400) + self.assertEqual( + json.loads(response.data.decode('utf-8'))['error'], + "Pool name shouldn't be empty", + ) + self.assertEqual(self._get_pool_count(), 2) + + def test_delete_pool(self): + response = self.app.delete( + '/api/experimental/pools/{}'.format(self.pool.pool), + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(json.loads(response.data.decode('utf-8')), + self.pool.to_json()) + self.assertEqual(self._get_pool_count(), 1) + + def test_delete_pool_non_existing(self): + response = self.app.delete( + '/api/experimental/pools/foo', + ) + self.assertEqual(response.status_code, 404) + self.assertEqual(json.loads(response.data.decode('utf-8'))['error'], + "Pool 'foo' doesn't exist") + + +if __name__ == '__main__': + unittest.main()