[AIRFLOW-1275] Put 'airflow pool' into API
Closes #2346 from skudriashev/airflow-1275
This commit is contained in:
Родитель
a45e2d1888
Коммит
9958aa9d53
|
@ -14,17 +14,47 @@
|
|||
#
|
||||
|
||||
|
||||
class Client:
|
||||
class Client(object):
|
||||
"""Base API client for all API clients."""
|
||||
|
||||
def __init__(self, api_base_url, auth):
|
||||
self._api_base_url = api_base_url
|
||||
self._auth = auth
|
||||
|
||||
def trigger_dag(self, dag_id, run_id=None, conf=None, execution_date=None):
|
||||
"""
|
||||
Creates a dag run for the specified dag
|
||||
"""Create a dag run for the specified dag.
|
||||
|
||||
:param dag_id:
|
||||
:param run_id:
|
||||
:param conf:
|
||||
:param execution_date:
|
||||
:return:
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_pool(self, name):
|
||||
"""Get pool.
|
||||
|
||||
:param name: pool name
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_pools(self):
|
||||
"""Get all pools."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def create_pool(self, name, slots, description):
|
||||
"""Create a pool.
|
||||
|
||||
:param name: pool name
|
||||
:param slots: pool slots amount
|
||||
:param description: pool description
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def delete_pool(self, name):
|
||||
"""Delete pool.
|
||||
|
||||
:param name: pool name
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
|
|
@ -11,30 +11,70 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from future.moves.urllib.parse import urljoin
|
||||
import requests
|
||||
|
||||
from airflow.api.client import api_client
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
class Client(api_client.Client):
|
||||
"""Json API client implementation."""
|
||||
|
||||
def _request(self, url, method='GET', json=None):
|
||||
params = {
|
||||
'url': url,
|
||||
'auth': self._auth,
|
||||
}
|
||||
if json is not None:
|
||||
params['json'] = json
|
||||
|
||||
resp = getattr(requests, method.lower())(**params)
|
||||
if not resp.ok:
|
||||
try:
|
||||
data = resp.json()
|
||||
except Exception:
|
||||
data = {}
|
||||
raise IOError(data.get('error', 'Server error'))
|
||||
|
||||
return resp.json()
|
||||
|
||||
def trigger_dag(self, dag_id, run_id=None, conf=None, execution_date=None):
|
||||
endpoint = '/api/experimental/dags/{}/dag_runs'.format(dag_id)
|
||||
url = urljoin(self._api_base_url, endpoint)
|
||||
|
||||
resp = requests.post(url,
|
||||
auth=self._auth,
|
||||
data = self._request(url, method='POST',
|
||||
json={
|
||||
"run_id": run_id,
|
||||
"conf": conf,
|
||||
"execution_date": execution_date,
|
||||
})
|
||||
|
||||
if not resp.ok:
|
||||
raise IOError()
|
||||
|
||||
data = resp.json()
|
||||
|
||||
return data['message']
|
||||
|
||||
def get_pool(self, name):
|
||||
endpoint = '/api/experimental/pools/{}'.format(name)
|
||||
url = urljoin(self._api_base_url, endpoint)
|
||||
pool = self._request(url)
|
||||
return pool['pool'], pool['slots'], pool['description']
|
||||
|
||||
def get_pools(self):
|
||||
endpoint = '/api/experimental/pools'
|
||||
url = urljoin(self._api_base_url, endpoint)
|
||||
pools = self._request(url)
|
||||
return [(p['pool'], p['slots'], p['description']) for p in pools]
|
||||
|
||||
def create_pool(self, name, slots, description):
|
||||
endpoint = '/api/experimental/pools'
|
||||
url = urljoin(self._api_base_url, endpoint)
|
||||
pool = self._request(url, method='POST',
|
||||
json={
|
||||
'name': name,
|
||||
'slots': slots,
|
||||
'description': description,
|
||||
})
|
||||
return pool['pool'], pool['slots'], pool['description']
|
||||
|
||||
def delete_pool(self, name):
|
||||
endpoint = '/api/experimental/pools/{}'.format(name)
|
||||
url = urljoin(self._api_base_url, endpoint)
|
||||
pool = self._request(url, method='DELETE')
|
||||
return pool['pool'], pool['slots'], pool['description']
|
||||
|
|
|
@ -11,15 +11,33 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from airflow.api.client import api_client
|
||||
from airflow.api.common.experimental import pool
|
||||
from airflow.api.common.experimental import trigger_dag
|
||||
|
||||
|
||||
class Client(api_client.Client):
|
||||
"""Local API client implementation."""
|
||||
|
||||
def trigger_dag(self, dag_id, run_id=None, conf=None, execution_date=None):
|
||||
dr = trigger_dag.trigger_dag(dag_id=dag_id,
|
||||
run_id=run_id,
|
||||
conf=conf,
|
||||
execution_date=execution_date)
|
||||
return "Created {}".format(dr)
|
||||
|
||||
def get_pool(self, name):
|
||||
p = pool.get_pool(name=name)
|
||||
return p.pool, p.slots, p.description
|
||||
|
||||
def get_pools(self):
|
||||
return [(p.pool, p.slots, p.description) for p in pool.get_pools()]
|
||||
|
||||
def create_pool(self, name, slots, description):
|
||||
p = pool.create_pool(name=name, slots=slots, description=description)
|
||||
return p.pool, p.slots, p.description
|
||||
|
||||
def delete_pool(self, name):
|
||||
p = pool.delete_pool(name=name)
|
||||
return p.pool, p.slots, p.description
|
||||
|
|
|
@ -0,0 +1,85 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from airflow.exceptions import AirflowException
|
||||
from airflow.models import Pool
|
||||
from airflow.utils.db import provide_session
|
||||
|
||||
|
||||
class PoolBadRequest(AirflowException):
|
||||
status = 400
|
||||
|
||||
|
||||
class PoolNotFound(AirflowException):
|
||||
status = 404
|
||||
|
||||
|
||||
@provide_session
|
||||
def get_pool(name, session=None):
|
||||
"""Get pool by a given name."""
|
||||
if not (name and name.strip()):
|
||||
raise PoolBadRequest("Pool name shouldn't be empty")
|
||||
|
||||
pool = session.query(Pool).filter_by(pool=name).first()
|
||||
if pool is None:
|
||||
raise PoolNotFound("Pool '%s' doesn't exist" % name)
|
||||
|
||||
return pool
|
||||
|
||||
|
||||
@provide_session
|
||||
def get_pools(session=None):
|
||||
"""Get all pools."""
|
||||
return session.query(Pool).all()
|
||||
|
||||
|
||||
@provide_session
|
||||
def create_pool(name, slots, description, session=None):
|
||||
"""Create a pool with a given parameters."""
|
||||
if not (name and name.strip()):
|
||||
raise PoolBadRequest("Pool name shouldn't be empty")
|
||||
|
||||
try:
|
||||
slots = int(slots)
|
||||
except ValueError:
|
||||
raise PoolBadRequest("Bad value for `slots`: %s" % slots)
|
||||
|
||||
session.expire_on_commit = False
|
||||
pool = session.query(Pool).filter_by(pool=name).first()
|
||||
if pool is None:
|
||||
pool = Pool(pool=name, slots=slots, description=description)
|
||||
session.add(pool)
|
||||
else:
|
||||
pool.slots = slots
|
||||
pool.description = description
|
||||
|
||||
session.commit()
|
||||
|
||||
return pool
|
||||
|
||||
|
||||
@provide_session
|
||||
def delete_pool(name, session=None):
|
||||
"""Delete pool by a given name."""
|
||||
if not (name and name.strip()):
|
||||
raise PoolBadRequest("Pool name shouldn't be empty")
|
||||
|
||||
pool = session.query(Pool).filter_by(pool=name).first()
|
||||
if pool is None:
|
||||
raise PoolNotFound("Pool '%s' doesn't exist" % name)
|
||||
|
||||
session.delete(pool)
|
||||
session.commit()
|
||||
|
||||
return pool
|
|
@ -49,7 +49,7 @@ from airflow.exceptions import AirflowException
|
|||
from airflow.executors import GetDefaultExecutor
|
||||
from airflow.models import (DagModel, DagBag, TaskInstance,
|
||||
DagPickle, DagRun, Variable, DagStat,
|
||||
Pool, Connection)
|
||||
Connection)
|
||||
from airflow.ti_deps.dep_context import (DepContext, SCHEDULER_DEPS)
|
||||
from airflow.utils import db as db_utils
|
||||
from airflow.utils import logging as logging_utils
|
||||
|
@ -187,40 +187,28 @@ def trigger_dag(args):
|
|||
|
||||
|
||||
def pool(args):
|
||||
session = settings.Session()
|
||||
if args.get or (args.set and args.set[0]) or args.delete:
|
||||
name = args.get or args.delete or args.set[0]
|
||||
pool = (
|
||||
session.query(Pool)
|
||||
.filter(Pool.pool == name)
|
||||
.first())
|
||||
if pool and args.get:
|
||||
print("{} ".format(pool))
|
||||
return
|
||||
elif not pool and (args.get or args.delete):
|
||||
print("No pool named {} found".format(name))
|
||||
elif not pool and args.set:
|
||||
pool = Pool(
|
||||
pool=name,
|
||||
slots=args.set[1],
|
||||
description=args.set[2])
|
||||
session.add(pool)
|
||||
session.commit()
|
||||
print("{} ".format(pool))
|
||||
elif pool and args.set:
|
||||
pool.slots = args.set[1]
|
||||
pool.description = args.set[2]
|
||||
session.commit()
|
||||
print("{} ".format(pool))
|
||||
return
|
||||
elif pool and args.delete:
|
||||
session.query(Pool).filter_by(pool=args.delete).delete()
|
||||
session.commit()
|
||||
print("Pool {} deleted".format(name))
|
||||
def _tabulate(pools):
|
||||
return "\n%s" % tabulate(pools, ['Pool', 'Slots', 'Description'],
|
||||
tablefmt="fancy_grid")
|
||||
|
||||
try:
|
||||
if args.get is not None:
|
||||
pools = [api_client.get_pool(name=args.get)]
|
||||
elif args.set:
|
||||
pools = [api_client.create_pool(name=args.set[0],
|
||||
slots=args.set[1],
|
||||
description=args.set[2])]
|
||||
elif args.delete:
|
||||
pools = [api_client.delete_pool(name=args.delete)]
|
||||
else:
|
||||
pools = api_client.get_pools()
|
||||
except (AirflowException, IOError) as err:
|
||||
logging.error(err)
|
||||
else:
|
||||
logging.info(_tabulate(pools=pools))
|
||||
|
||||
|
||||
def variables(args):
|
||||
|
||||
if args.get:
|
||||
try:
|
||||
var = Variable.get(args.get,
|
||||
|
|
|
@ -4395,6 +4395,14 @@ class Pool(Base):
|
|||
def __repr__(self):
|
||||
return self.pool
|
||||
|
||||
def to_json(self):
|
||||
return {
|
||||
'id': self.id,
|
||||
'pool': self.pool,
|
||||
'slots': self.slots,
|
||||
'description': self.description,
|
||||
}
|
||||
|
||||
@provide_session
|
||||
def used_slots(self, session):
|
||||
"""
|
||||
|
|
|
@ -11,10 +11,12 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
|
||||
import airflow.api
|
||||
|
||||
from airflow.api.common.experimental import pool as pool_api
|
||||
from airflow.api.common.experimental import trigger_dag as trigger
|
||||
from airflow.api.common.experimental.get_task import get_task
|
||||
from airflow.api.common.experimental.get_task_instance import get_task_instance
|
||||
|
@ -96,7 +98,6 @@ def test():
|
|||
@requires_authentication
|
||||
def task_info(dag_id, task_id):
|
||||
"""Returns a JSON with a task's public instance variables. """
|
||||
|
||||
try:
|
||||
info = get_task(dag_id, task_id)
|
||||
except AirflowException as err:
|
||||
|
@ -169,4 +170,67 @@ def latest_dag_runs():
|
|||
'dag_run_url': url_for('airflow.graph', dag_id=dagrun.dag_id,
|
||||
execution_date=dagrun.execution_date)
|
||||
})
|
||||
return jsonify(items=payload) # old flask versions dont support jsonifying arrays
|
||||
return jsonify(items=payload) # old flask versions dont support jsonifying arrays
|
||||
|
||||
|
||||
@api_experimental.route('/pools/<string:name>', methods=['GET'])
|
||||
@requires_authentication
|
||||
def get_pool(name):
|
||||
"""Get pool by a given name."""
|
||||
try:
|
||||
pool = pool_api.get_pool(name=name)
|
||||
except AirflowException as e:
|
||||
_log.error(e)
|
||||
response = jsonify(error="{}".format(e))
|
||||
response.status_code = getattr(e, 'status', 500)
|
||||
return response
|
||||
else:
|
||||
return jsonify(pool.to_json())
|
||||
|
||||
|
||||
@api_experimental.route('/pools', methods=['GET'])
|
||||
@requires_authentication
|
||||
def get_pools():
|
||||
"""Get all pools."""
|
||||
try:
|
||||
pools = pool_api.get_pools()
|
||||
except AirflowException as e:
|
||||
_log.error(e)
|
||||
response = jsonify(error="{}".format(e))
|
||||
response.status_code = getattr(e, 'status', 500)
|
||||
return response
|
||||
else:
|
||||
return jsonify([p.to_json() for p in pools])
|
||||
|
||||
|
||||
@csrf.exempt
|
||||
@api_experimental.route('/pools', methods=['POST'])
|
||||
@requires_authentication
|
||||
def create_pool():
|
||||
"""Create a pool."""
|
||||
params = request.get_json(force=True)
|
||||
try:
|
||||
pool = pool_api.create_pool(**params)
|
||||
except AirflowException as e:
|
||||
_log.error(e)
|
||||
response = jsonify(error="{}".format(e))
|
||||
response.status_code = getattr(e, 'status', 500)
|
||||
return response
|
||||
else:
|
||||
return jsonify(pool.to_json())
|
||||
|
||||
|
||||
@csrf.exempt
|
||||
@api_experimental.route('/pools/<string:name>', methods=['DELETE'])
|
||||
@requires_authentication
|
||||
def delete_pool(name):
|
||||
"""Delete pool."""
|
||||
try:
|
||||
pool = pool_api.delete_pool(name=name)
|
||||
except AirflowException as e:
|
||||
_log.error(e)
|
||||
response = jsonify(error="{}".format(e))
|
||||
response.status_code = getattr(e, 'status', 500)
|
||||
return response
|
||||
else:
|
||||
return jsonify(pool.to_json())
|
||||
|
|
|
@ -11,9 +11,3 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
|
||||
from .client import *
|
||||
from .common import *
|
||||
|
||||
|
|
|
@ -12,16 +12,16 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import datetime
|
||||
import json
|
||||
import unittest
|
||||
import datetime
|
||||
|
||||
from mock import patch
|
||||
|
||||
from airflow import AirflowException
|
||||
from airflow import models
|
||||
|
||||
from airflow.api.client.local_client import Client
|
||||
from airflow import models
|
||||
from airflow import settings
|
||||
from airflow.utils.state import State
|
||||
|
||||
EXECDATE = datetime.datetime.now()
|
||||
|
@ -53,8 +53,25 @@ def mock_datetime_now(target, dt):
|
|||
|
||||
|
||||
class TestLocalClient(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super(TestLocalClient, cls).setUpClass()
|
||||
session = settings.Session()
|
||||
session.query(models.Pool).delete()
|
||||
session.commit()
|
||||
session.close()
|
||||
|
||||
def setUp(self):
|
||||
super(TestLocalClient, self).setUp()
|
||||
self.client = Client(api_base_url=None, auth=None)
|
||||
self.session = settings.Session()
|
||||
|
||||
def tearDown(self):
|
||||
self.session.query(models.Pool).delete()
|
||||
self.session.commit()
|
||||
self.session.close()
|
||||
super(TestLocalClient, self).tearDown()
|
||||
|
||||
@patch.object(models.DAG, 'create_dagrun')
|
||||
def test_trigger_dag(self, mock):
|
||||
|
@ -104,4 +121,24 @@ class TestLocalClient(unittest.TestCase):
|
|||
external_trigger=True)
|
||||
mock.reset_mock()
|
||||
|
||||
# this is a unit test only, cannot verify existing dag run
|
||||
def test_get_pool(self):
|
||||
self.client.create_pool(name='foo', slots=1, description='')
|
||||
pool = self.client.get_pool(name='foo')
|
||||
self.assertEqual(pool, ('foo', 1, ''))
|
||||
|
||||
def test_get_pools(self):
|
||||
self.client.create_pool(name='foo1', slots=1, description='')
|
||||
self.client.create_pool(name='foo2', slots=2, description='')
|
||||
pools = sorted(self.client.get_pools(), key=lambda p: p[0])
|
||||
self.assertEqual(pools, [('foo1', 1, ''), ('foo2', 2, '')])
|
||||
|
||||
def test_create_pool(self):
|
||||
pool = self.client.create_pool(name='foo', slots=1, description='')
|
||||
self.assertEqual(pool, ('foo', 1, ''))
|
||||
self.assertEqual(self.session.query(models.Pool).count(), 1)
|
||||
|
||||
def test_delete_pool(self):
|
||||
self.client.create_pool(name='foo', slots=1, description='')
|
||||
self.assertEqual(self.session.query(models.Pool).count(), 1)
|
||||
self.client.delete_pool(name='foo')
|
||||
self.assertEqual(self.session.query(models.Pool).count(), 0)
|
|
@ -0,0 +1,13 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
|
@ -11,7 +11,6 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import unittest
|
||||
|
||||
|
@ -27,6 +26,7 @@ DEV_NULL = "/dev/null"
|
|||
|
||||
|
||||
class TestMarkTasks(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.dagbag = models.DagBag(include_examples=True)
|
||||
self.dag1 = self.dagbag.dags['test_example_bash_operator']
|
||||
|
@ -52,6 +52,16 @@ class TestMarkTasks(unittest.TestCase):
|
|||
|
||||
self.session = Session()
|
||||
|
||||
def tearDown(self):
|
||||
self.dag1.clear()
|
||||
self.dag2.clear()
|
||||
|
||||
# just to make sure we are fully cleaned up
|
||||
self.session.query(models.DagRun).delete()
|
||||
self.session.query(models.TaskInstance).delete()
|
||||
self.session.commit()
|
||||
self.session.close()
|
||||
|
||||
def snapshot_state(self, dag, execution_dates):
|
||||
TI = models.TaskInstance
|
||||
tis = self.session.query(TI).filter(
|
||||
|
@ -197,16 +207,6 @@ class TestMarkTasks(unittest.TestCase):
|
|||
self.verify_state(self.dag2, task_ids, [self.execution_dates[0]],
|
||||
State.SUCCESS, [])
|
||||
|
||||
def tearDown(self):
|
||||
self.dag1.clear()
|
||||
self.dag2.clear()
|
||||
|
||||
# just to make sure we are fully cleaned up
|
||||
self.session.query(models.DagRun).delete()
|
||||
self.session.query(models.TaskInstance).delete()
|
||||
self.session.commit()
|
||||
|
||||
self.session.close()
|
||||
|
||||
class TestMarkDAGRun(unittest.TestCase):
|
||||
def setUp(self):
|
|
@ -0,0 +1,132 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from airflow.api.common.experimental import pool as pool_api
|
||||
from airflow import models
|
||||
from airflow import settings
|
||||
|
||||
|
||||
class TestPool(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(TestPool, self).setUp()
|
||||
self.session = settings.Session()
|
||||
self.pools = []
|
||||
for i in range(2):
|
||||
name = 'experimental_%s' % (i + 1)
|
||||
pool = models.Pool(
|
||||
pool=name,
|
||||
slots=i,
|
||||
description=name,
|
||||
)
|
||||
self.session.add(pool)
|
||||
self.pools.append(pool)
|
||||
self.session.commit()
|
||||
|
||||
def tearDown(self):
|
||||
self.session.query(models.Pool).delete()
|
||||
self.session.commit()
|
||||
self.session.close()
|
||||
super(TestPool, self).tearDown()
|
||||
|
||||
def test_get_pool(self):
|
||||
pool = pool_api.get_pool(name=self.pools[0].pool, session=self.session)
|
||||
self.assertEqual(pool.pool, self.pools[0].pool)
|
||||
|
||||
def test_get_pool_non_existing(self):
|
||||
self.assertRaisesRegexp(pool_api.PoolNotFound,
|
||||
"^Pool 'test' doesn't exist$",
|
||||
pool_api.get_pool,
|
||||
name='test',
|
||||
session=self.session)
|
||||
|
||||
def test_get_pool_bad_name(self):
|
||||
for name in ('', ' '):
|
||||
self.assertRaisesRegexp(pool_api.PoolBadRequest,
|
||||
"^Pool name shouldn't be empty$",
|
||||
pool_api.get_pool,
|
||||
name=name,
|
||||
session=self.session)
|
||||
|
||||
def test_get_pools(self):
|
||||
pools = sorted(pool_api.get_pools(session=self.session),
|
||||
key=lambda p: p.pool)
|
||||
self.assertEqual(pools[0].pool, self.pools[0].pool)
|
||||
self.assertEqual(pools[1].pool, self.pools[1].pool)
|
||||
|
||||
def test_create_pool(self):
|
||||
pool = pool_api.create_pool(name='foo',
|
||||
slots=5,
|
||||
description='',
|
||||
session=self.session)
|
||||
self.assertEqual(pool.pool, 'foo')
|
||||
self.assertEqual(pool.slots, 5)
|
||||
self.assertEqual(pool.description, '')
|
||||
self.assertEqual(self.session.query(models.Pool).count(), 3)
|
||||
|
||||
def test_create_pool_existing(self):
|
||||
pool = pool_api.create_pool(name=self.pools[0].pool,
|
||||
slots=5,
|
||||
description='',
|
||||
session=self.session)
|
||||
self.assertEqual(pool.pool, self.pools[0].pool)
|
||||
self.assertEqual(pool.slots, 5)
|
||||
self.assertEqual(pool.description, '')
|
||||
self.assertEqual(self.session.query(models.Pool).count(), 2)
|
||||
|
||||
def test_create_pool_bad_name(self):
|
||||
for name in ('', ' '):
|
||||
self.assertRaisesRegexp(pool_api.PoolBadRequest,
|
||||
"^Pool name shouldn't be empty$",
|
||||
pool_api.create_pool,
|
||||
name=name,
|
||||
slots=5,
|
||||
description='',
|
||||
session=self.session)
|
||||
|
||||
def test_create_pool_bad_slots(self):
|
||||
self.assertRaisesRegexp(pool_api.PoolBadRequest,
|
||||
"^Bad value for `slots`: foo$",
|
||||
pool_api.create_pool,
|
||||
name='foo',
|
||||
slots='foo',
|
||||
description='',
|
||||
session=self.session)
|
||||
|
||||
def test_delete_pool(self):
|
||||
pool = pool_api.delete_pool(name=self.pools[0].pool,
|
||||
session=self.session)
|
||||
self.assertEqual(pool.pool, self.pools[0].pool)
|
||||
self.assertEqual(self.session.query(models.Pool).count(), 1)
|
||||
|
||||
def test_delete_pool_non_existing(self):
|
||||
self.assertRaisesRegexp(pool_api.PoolNotFound,
|
||||
"^Pool 'test' doesn't exist$",
|
||||
pool_api.delete_pool,
|
||||
name='test',
|
||||
session=self.session)
|
||||
|
||||
def test_delete_pool_bad_name(self):
|
||||
for name in ('', ' '):
|
||||
self.assertRaisesRegexp(pool_api.PoolBadRequest,
|
||||
"^Pool name shouldn't be empty$",
|
||||
pool_api.delete_pool,
|
||||
name=name,
|
||||
session=self.session)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
|
@ -1062,12 +1062,34 @@ class CoreTest(unittest.TestCase):
|
|||
|
||||
|
||||
class CliTests(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super(CliTests, cls).setUpClass()
|
||||
cls._cleanup()
|
||||
|
||||
def setUp(self):
|
||||
super(CliTests, self).setUp()
|
||||
configuration.load_test_config()
|
||||
app = application.create_app()
|
||||
app.config['TESTING'] = True
|
||||
self.parser = cli.CLIFactory.get_parser()
|
||||
self.dagbag = models.DagBag(dag_folder=DEV_NULL, include_examples=True)
|
||||
self.session = Session()
|
||||
|
||||
def tearDown(self):
|
||||
self._cleanup(session=self.session)
|
||||
super(CliTests, self).tearDown()
|
||||
|
||||
@staticmethod
|
||||
def _cleanup(session=None):
|
||||
if session is None:
|
||||
session = Session()
|
||||
|
||||
session.query(models.Pool).delete()
|
||||
session.query(models.Variable).delete()
|
||||
session.commit()
|
||||
session.close()
|
||||
|
||||
def test_cli_list_dags(self):
|
||||
args = self.parser.parse_args(['list_dags', '--report'])
|
||||
|
@ -1100,8 +1122,8 @@ class CliTests(unittest.TestCase):
|
|||
cli.connections(self.parser.parse_args(['connections', '--list']))
|
||||
stdout = mock_stdout.getvalue()
|
||||
conns = [[x.strip("'") for x in re.findall("'\w+'", line)[:2]]
|
||||
for ii, line in enumerate(stdout.split('\n'))
|
||||
if ii % 2 == 1]
|
||||
for ii, line in enumerate(stdout.split('\n'))
|
||||
if ii % 2 == 1]
|
||||
conns = [conn for conn in conns if len(conn) > 0]
|
||||
|
||||
# Assert that some of the connections are present in the output as
|
||||
|
@ -1365,14 +1387,27 @@ class CliTests(unittest.TestCase):
|
|||
'-c', 'NOT JSON'])
|
||||
)
|
||||
|
||||
def test_pool(self):
|
||||
# Checks if all subcommands are properly received
|
||||
cli.pool(self.parser.parse_args([
|
||||
'pool', '-s', 'foo', '1', '"my foo pool"']))
|
||||
cli.pool(self.parser.parse_args([
|
||||
'pool', '-g', 'foo']))
|
||||
cli.pool(self.parser.parse_args([
|
||||
'pool', '-x', 'foo']))
|
||||
def test_pool_create(self):
|
||||
cli.pool(self.parser.parse_args(['pool', '-s', 'foo', '1', 'test']))
|
||||
self.assertEqual(self.session.query(models.Pool).count(), 1)
|
||||
|
||||
def test_pool_get(self):
|
||||
cli.pool(self.parser.parse_args(['pool', '-s', 'foo', '1', 'test']))
|
||||
try:
|
||||
cli.pool(self.parser.parse_args(['pool', '-g', 'foo']))
|
||||
except Exception as e:
|
||||
self.fail("The 'pool -g foo' command raised unexpectedly: %s" % e)
|
||||
|
||||
def test_pool_delete(self):
|
||||
cli.pool(self.parser.parse_args(['pool', '-s', 'foo', '1', 'test']))
|
||||
cli.pool(self.parser.parse_args(['pool', '-x', 'foo']))
|
||||
self.assertEqual(self.session.query(models.Pool).count(), 0)
|
||||
|
||||
def test_pool_no_args(self):
|
||||
try:
|
||||
cli.pool(self.parser.parse_args(['pool']))
|
||||
except Exception as e:
|
||||
self.fail("The 'pool' command raised unexpectedly: %s" % e)
|
||||
|
||||
def test_variables(self):
|
||||
# Checks if all subcommands are properly received
|
||||
|
@ -1426,10 +1461,6 @@ class CliTests(unittest.TestCase):
|
|||
self.assertEqual('original', models.Variable.get('bar'))
|
||||
self.assertEqual('{"foo": "bar"}', models.Variable.get('foo'))
|
||||
|
||||
session = settings.Session()
|
||||
session.query(Variable).delete()
|
||||
session.commit()
|
||||
session.close()
|
||||
os.remove('variables1.json')
|
||||
os.remove('variables2.json')
|
||||
|
||||
|
|
|
@ -19,23 +19,36 @@ from urllib.parse import quote_plus
|
|||
|
||||
from airflow import configuration
|
||||
from airflow.api.common.experimental.trigger_dag import trigger_dag
|
||||
from airflow.models import DagBag, DagRun, TaskInstance
|
||||
from airflow.models import DagBag, DagRun, Pool, TaskInstance
|
||||
from airflow.settings import Session
|
||||
from airflow.www import app as application
|
||||
|
||||
|
||||
class ApiExperimentalTests(unittest.TestCase):
|
||||
class TestApiExperimental(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
configuration.load_test_config()
|
||||
app = application.create_app(testing=True)
|
||||
self.app = app.test_client()
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super(TestApiExperimental, cls).setUpClass()
|
||||
session = Session()
|
||||
session.query(DagRun).delete()
|
||||
session.query(TaskInstance).delete()
|
||||
session.commit()
|
||||
session.close()
|
||||
|
||||
def setUp(self):
|
||||
super(TestApiExperimental, self).setUp()
|
||||
configuration.load_test_config()
|
||||
app = application.create_app(testing=True)
|
||||
self.app = app.test_client()
|
||||
|
||||
def tearDown(self):
|
||||
session = Session()
|
||||
session.query(DagRun).delete()
|
||||
session.query(TaskInstance).delete()
|
||||
session.commit()
|
||||
session.close()
|
||||
super(TestApiExperimental, self).tearDown()
|
||||
|
||||
def test_task_info(self):
|
||||
url_template = '/api/experimental/dags/{}/tasks/{}'
|
||||
|
||||
|
@ -62,7 +75,7 @@ class ApiExperimentalTests(unittest.TestCase):
|
|||
url_template = '/api/experimental/dags/{}/dag_runs'
|
||||
response = self.app.post(
|
||||
url_template.format('example_bash_operator'),
|
||||
data=json.dumps(dict(run_id='my_run' + datetime.now().isoformat())),
|
||||
data=json.dumps({'run_id': 'my_run' + datetime.now().isoformat()}),
|
||||
content_type="application/json"
|
||||
)
|
||||
|
||||
|
@ -70,7 +83,7 @@ class ApiExperimentalTests(unittest.TestCase):
|
|||
|
||||
response = self.app.post(
|
||||
url_template.format('does_not_exist_dag'),
|
||||
data=json.dumps(dict()),
|
||||
data=json.dumps({}),
|
||||
content_type="application/json"
|
||||
)
|
||||
self.assertEqual(404, response.status_code)
|
||||
|
@ -88,7 +101,7 @@ class ApiExperimentalTests(unittest.TestCase):
|
|||
# Test Correct execution
|
||||
response = self.app.post(
|
||||
url_template.format(dag_id),
|
||||
data=json.dumps(dict(execution_date=datetime_string)),
|
||||
data=json.dumps({'execution_date': datetime_string}),
|
||||
content_type="application/json"
|
||||
)
|
||||
self.assertEqual(200, response.status_code)
|
||||
|
@ -103,7 +116,7 @@ class ApiExperimentalTests(unittest.TestCase):
|
|||
# Test error for nonexistent dag
|
||||
response = self.app.post(
|
||||
url_template.format('does_not_exist_dag'),
|
||||
data=json.dumps(dict(execution_date=execution_date.isoformat())),
|
||||
data=json.dumps({'execution_date': execution_date.isoformat()}),
|
||||
content_type="application/json"
|
||||
)
|
||||
self.assertEqual(404, response.status_code)
|
||||
|
@ -111,7 +124,7 @@ class ApiExperimentalTests(unittest.TestCase):
|
|||
# Test error for bad datetime format
|
||||
response = self.app.post(
|
||||
url_template.format(dag_id),
|
||||
data=json.dumps(dict(execution_date='not_a_datetime')),
|
||||
data=json.dumps({'execution_date': 'not_a_datetime'}),
|
||||
content_type="application/json"
|
||||
)
|
||||
self.assertEqual(400, response.status_code)
|
||||
|
@ -122,7 +135,9 @@ class ApiExperimentalTests(unittest.TestCase):
|
|||
task_id = 'also_run_this'
|
||||
execution_date = datetime.now().replace(microsecond=0)
|
||||
datetime_string = quote_plus(execution_date.isoformat())
|
||||
wrong_datetime_string = quote_plus(datetime(1990, 1, 1, 1, 1, 1).isoformat())
|
||||
wrong_datetime_string = quote_plus(
|
||||
datetime(1990, 1, 1, 1, 1, 1).isoformat()
|
||||
)
|
||||
|
||||
# Create DagRun
|
||||
trigger_dag(dag_id=dag_id,
|
||||
|
@ -139,7 +154,8 @@ class ApiExperimentalTests(unittest.TestCase):
|
|||
|
||||
# Test error for nonexistent dag
|
||||
response = self.app.get(
|
||||
url_template.format('does_not_exist_dag', datetime_string, task_id),
|
||||
url_template.format('does_not_exist_dag', datetime_string,
|
||||
task_id),
|
||||
)
|
||||
self.assertEqual(404, response.status_code)
|
||||
self.assertIn('error', response.data.decode('utf-8'))
|
||||
|
@ -164,3 +180,122 @@ class ApiExperimentalTests(unittest.TestCase):
|
|||
)
|
||||
self.assertEqual(400, response.status_code)
|
||||
self.assertIn('error', response.data.decode('utf-8'))
|
||||
|
||||
|
||||
class TestPoolApiExperimental(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super(TestPoolApiExperimental, cls).setUpClass()
|
||||
session = Session()
|
||||
session.query(Pool).delete()
|
||||
session.commit()
|
||||
session.close()
|
||||
|
||||
def setUp(self):
|
||||
super(TestPoolApiExperimental, self).setUp()
|
||||
configuration.load_test_config()
|
||||
app = application.create_app(testing=True)
|
||||
self.app = app.test_client()
|
||||
self.session = Session()
|
||||
self.pools = []
|
||||
for i in range(2):
|
||||
name = 'experimental_%s' % (i + 1)
|
||||
pool = Pool(
|
||||
pool=name,
|
||||
slots=i,
|
||||
description=name,
|
||||
)
|
||||
self.session.add(pool)
|
||||
self.pools.append(pool)
|
||||
self.session.commit()
|
||||
self.pool = self.pools[0]
|
||||
|
||||
def tearDown(self):
|
||||
self.session.query(Pool).delete()
|
||||
self.session.commit()
|
||||
self.session.close()
|
||||
super(TestPoolApiExperimental, self).tearDown()
|
||||
|
||||
def _get_pool_count(self):
|
||||
response = self.app.get('/api/experimental/pools')
|
||||
self.assertEqual(response.status_code, 200)
|
||||
return len(json.loads(response.data.decode('utf-8')))
|
||||
|
||||
def test_get_pool(self):
|
||||
response = self.app.get(
|
||||
'/api/experimental/pools/{}'.format(self.pool.pool),
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertEqual(json.loads(response.data.decode('utf-8')),
|
||||
self.pool.to_json())
|
||||
|
||||
def test_get_pool_non_existing(self):
|
||||
response = self.app.get('/api/experimental/pools/foo')
|
||||
self.assertEqual(response.status_code, 404)
|
||||
self.assertEqual(json.loads(response.data.decode('utf-8'))['error'],
|
||||
"Pool 'foo' doesn't exist")
|
||||
|
||||
def test_get_pools(self):
|
||||
response = self.app.get('/api/experimental/pools')
|
||||
self.assertEqual(response.status_code, 200)
|
||||
pools = json.loads(response.data.decode('utf-8'))
|
||||
self.assertEqual(len(pools), 2)
|
||||
for i, pool in enumerate(sorted(pools, key=lambda p: p['pool'])):
|
||||
self.assertDictEqual(pool, self.pools[i].to_json())
|
||||
|
||||
def test_create_pool(self):
|
||||
response = self.app.post(
|
||||
'/api/experimental/pools',
|
||||
data=json.dumps({
|
||||
'name': 'foo',
|
||||
'slots': 1,
|
||||
'description': '',
|
||||
}),
|
||||
content_type='application/json',
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
pool = json.loads(response.data.decode('utf-8'))
|
||||
self.assertEqual(pool['pool'], 'foo')
|
||||
self.assertEqual(pool['slots'], 1)
|
||||
self.assertEqual(pool['description'], '')
|
||||
self.assertEqual(self._get_pool_count(), 3)
|
||||
|
||||
def test_create_pool_with_bad_name(self):
|
||||
for name in ('', ' '):
|
||||
response = self.app.post(
|
||||
'/api/experimental/pools',
|
||||
data=json.dumps({
|
||||
'name': name,
|
||||
'slots': 1,
|
||||
'description': '',
|
||||
}),
|
||||
content_type='application/json',
|
||||
)
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertEqual(
|
||||
json.loads(response.data.decode('utf-8'))['error'],
|
||||
"Pool name shouldn't be empty",
|
||||
)
|
||||
self.assertEqual(self._get_pool_count(), 2)
|
||||
|
||||
def test_delete_pool(self):
|
||||
response = self.app.delete(
|
||||
'/api/experimental/pools/{}'.format(self.pool.pool),
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertEqual(json.loads(response.data.decode('utf-8')),
|
||||
self.pool.to_json())
|
||||
self.assertEqual(self._get_pool_count(), 1)
|
||||
|
||||
def test_delete_pool_non_existing(self):
|
||||
response = self.app.delete(
|
||||
'/api/experimental/pools/foo',
|
||||
)
|
||||
self.assertEqual(response.status_code, 404)
|
||||
self.assertEqual(json.loads(response.data.decode('utf-8'))['error'],
|
||||
"Pool 'foo' doesn't exist")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
Загрузка…
Ссылка в новой задаче