[AIRFLOW-1275] Put 'airflow pool' into API
Closes #2346 from skudriashev/airflow-1275
Parent: a45e2d1888
Commit: 9958aa9d53
@@ -14,17 +14,47 @@
 #
 
 
-class Client:
+class Client(object):
+    """Base API client for all API clients."""
 
     def __init__(self, api_base_url, auth):
        self._api_base_url = api_base_url
        self._auth = auth
 
     def trigger_dag(self, dag_id, run_id=None, conf=None, execution_date=None):
-        """
-        Creates a dag run for the specified dag
+        """Create a dag run for the specified dag.
+
         :param dag_id:
         :param run_id:
         :param conf:
+        :param execution_date:
         :return:
         """
         raise NotImplementedError()
+
+    def get_pool(self, name):
+        """Get pool.
+
+        :param name: pool name
+        """
+        raise NotImplementedError()
+
+    def get_pools(self):
+        """Get all pools."""
+        raise NotImplementedError()
+
+    def create_pool(self, name, slots, description):
+        """Create a pool.
+
+        :param name: pool name
+        :param slots: pool slots amount
+        :param description: pool description
+        """
+        raise NotImplementedError()
+
+    def delete_pool(self, name):
+        """Delete pool.
+
+        :param name: pool name
+        """
+        raise NotImplementedError()
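The abstract client above only fixes the contract; the JSON and local clients in the following hunks are the real implementations. Purely as an illustration (this DictClient is hypothetical and not part of the commit), a minimal in-memory subclass satisfying the new pool interface could look like:

from airflow.api.client import api_client


class DictClient(api_client.Client):
    """Hypothetical in-memory client, for illustration only."""

    def __init__(self, api_base_url=None, auth=None):
        super(DictClient, self).__init__(api_base_url, auth)
        self._pools = {}

    def get_pool(self, name):
        # Mirrors the (pool, slots, description) tuples the real clients return.
        return self._pools[name]

    def get_pools(self):
        return sorted(self._pools.values())

    def create_pool(self, name, slots, description):
        self._pools[name] = (name, slots, description)
        return self._pools[name]

    def delete_pool(self, name):
        return self._pools.pop(name)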
@@ -11,30 +11,70 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-#
+
 from future.moves.urllib.parse import urljoin
+import requests
 
 from airflow.api.client import api_client
 
-import requests
-
 
 class Client(api_client.Client):
+    """Json API client implementation."""
+
+    def _request(self, url, method='GET', json=None):
+        params = {
+            'url': url,
+            'auth': self._auth,
+        }
+        if json is not None:
+            params['json'] = json
+
+        resp = getattr(requests, method.lower())(**params)
+        if not resp.ok:
+            try:
+                data = resp.json()
+            except Exception:
+                data = {}
+            raise IOError(data.get('error', 'Server error'))
+
+        return resp.json()
 
     def trigger_dag(self, dag_id, run_id=None, conf=None, execution_date=None):
         endpoint = '/api/experimental/dags/{}/dag_runs'.format(dag_id)
         url = urljoin(self._api_base_url, endpoint)
-        resp = requests.post(url,
-                             auth=self._auth,
+        data = self._request(url, method='POST',
                              json={
                                  "run_id": run_id,
                                  "conf": conf,
                                  "execution_date": execution_date,
                              })
-
-        if not resp.ok:
-            raise IOError()
-
-        data = resp.json()
-
         return data['message']
+
+    def get_pool(self, name):
+        endpoint = '/api/experimental/pools/{}'.format(name)
+        url = urljoin(self._api_base_url, endpoint)
+        pool = self._request(url)
+        return pool['pool'], pool['slots'], pool['description']
+
+    def get_pools(self):
+        endpoint = '/api/experimental/pools'
+        url = urljoin(self._api_base_url, endpoint)
+        pools = self._request(url)
+        return [(p['pool'], p['slots'], p['description']) for p in pools]
+
+    def create_pool(self, name, slots, description):
+        endpoint = '/api/experimental/pools'
+        url = urljoin(self._api_base_url, endpoint)
+        pool = self._request(url, method='POST',
+                             json={
+                                 'name': name,
+                                 'slots': slots,
+                                 'description': description,
+                             })
+        return pool['pool'], pool['slots'], pool['description']
+
+    def delete_pool(self, name):
+        endpoint = '/api/experimental/pools/{}'.format(name)
+        url = urljoin(self._api_base_url, endpoint)
+        pool = self._request(url, method='DELETE')
+        return pool['pool'], pool['slots'], pool['description']
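A short usage sketch for the JSON client above (the module path airflow.api.client.json_client, the base URL, and the open auth are assumptions; the client simply issues HTTP calls with requests and unpacks the JSON responses):

from airflow.api.client.json_client import Client  # assumed module path

client = Client(api_base_url='http://localhost:8080', auth=None)
client.create_pool(name='backfill', slots=4, description='backfill pool')
print(client.get_pool(name='backfill'))   # ('backfill', 4, 'backfill pool')
print(client.get_pools())                 # list of (pool, slots, description) tuples
client.delete_pool(name='backfill')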
@@ -11,15 +11,33 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-#
+
 from airflow.api.client import api_client
+from airflow.api.common.experimental import pool
 from airflow.api.common.experimental import trigger_dag
 
 
 class Client(api_client.Client):
+    """Local API client implementation."""
+
     def trigger_dag(self, dag_id, run_id=None, conf=None, execution_date=None):
         dr = trigger_dag.trigger_dag(dag_id=dag_id,
                                      run_id=run_id,
                                      conf=conf,
                                      execution_date=execution_date)
         return "Created {}".format(dr)
+
+    def get_pool(self, name):
+        p = pool.get_pool(name=name)
+        return p.pool, p.slots, p.description
+
+    def get_pools(self):
+        return [(p.pool, p.slots, p.description) for p in pool.get_pools()]
+
+    def create_pool(self, name, slots, description):
+        p = pool.create_pool(name=name, slots=slots, description=description)
+        return p.pool, p.slots, p.description
+
+    def delete_pool(self, name):
+        p = pool.delete_pool(name=name)
+        return p.pool, p.slots, p.description
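The local client mirrors the JSON client but goes straight to the metadata database through the new experimental pool module. A minimal sketch, matching what the unit tests later in this commit do:

from airflow.api.client.local_client import Client

client = Client(api_base_url=None, auth=None)
client.create_pool(name='backfill', slots=4, description='backfill pool')
print(client.get_pools())   # [('backfill', 4, 'backfill pool'), ...]
client.delete_pool(name='backfill')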
@@ -0,0 +1,85 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from airflow.exceptions import AirflowException
+from airflow.models import Pool
+from airflow.utils.db import provide_session
+
+
+class PoolBadRequest(AirflowException):
+    status = 400
+
+
+class PoolNotFound(AirflowException):
+    status = 404
+
+
+@provide_session
+def get_pool(name, session=None):
+    """Get pool by a given name."""
+    if not (name and name.strip()):
+        raise PoolBadRequest("Pool name shouldn't be empty")
+
+    pool = session.query(Pool).filter_by(pool=name).first()
+    if pool is None:
+        raise PoolNotFound("Pool '%s' doesn't exist" % name)
+
+    return pool
+
+
+@provide_session
+def get_pools(session=None):
+    """Get all pools."""
+    return session.query(Pool).all()
+
+
+@provide_session
+def create_pool(name, slots, description, session=None):
+    """Create a pool with a given parameters."""
+    if not (name and name.strip()):
+        raise PoolBadRequest("Pool name shouldn't be empty")
+
+    try:
+        slots = int(slots)
+    except ValueError:
+        raise PoolBadRequest("Bad value for `slots`: %s" % slots)
+
+    session.expire_on_commit = False
+    pool = session.query(Pool).filter_by(pool=name).first()
+    if pool is None:
+        pool = Pool(pool=name, slots=slots, description=description)
+        session.add(pool)
+    else:
+        pool.slots = slots
+        pool.description = description
+
+    session.commit()
+
+    return pool
+
+
+@provide_session
+def delete_pool(name, session=None):
+    """Delete pool by a given name."""
+    if not (name and name.strip()):
+        raise PoolBadRequest("Pool name shouldn't be empty")
+
+    pool = session.query(Pool).filter_by(pool=name).first()
+    if pool is None:
+        raise PoolNotFound("Pool '%s' doesn't exist" % name)
+
+    session.delete(pool)
+    session.commit()
+
+    return pool
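Everything in this new module is wrapped with @provide_session, so callers can pass an explicit SQLAlchemy session (as the unit tests below do) or let the decorator open one. A minimal sketch:

from airflow.api.common.experimental import pool

# create_pool upserts: it creates the pool, or updates slots/description if it already exists.
p = pool.create_pool(name='backfill', slots=4, description='backfill pool')
print(p.pool, p.slots, p.description)

# Empty names raise PoolBadRequest (status 400); unknown names raise PoolNotFound (status 404).
try:
    pool.get_pool(name='does_not_exist')
except pool.PoolNotFound as err:
    print(err)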
@@ -49,7 +49,7 @@ from airflow.exceptions import AirflowException
 from airflow.executors import GetDefaultExecutor
 from airflow.models import (DagModel, DagBag, TaskInstance,
                             DagPickle, DagRun, Variable, DagStat,
-                            Pool, Connection)
+                            Connection)
 from airflow.ti_deps.dep_context import (DepContext, SCHEDULER_DEPS)
 from airflow.utils import db as db_utils
 from airflow.utils import logging as logging_utils
@@ -187,40 +187,28 @@ def trigger_dag(args):
 
 
 def pool(args):
-    session = settings.Session()
-    if args.get or (args.set and args.set[0]) or args.delete:
-        name = args.get or args.delete or args.set[0]
-        pool = (
-            session.query(Pool)
-            .filter(Pool.pool == name)
-            .first())
-        if pool and args.get:
-            print("{} ".format(pool))
-            return
-        elif not pool and (args.get or args.delete):
-            print("No pool named {} found".format(name))
-        elif not pool and args.set:
-            pool = Pool(
-                pool=name,
-                slots=args.set[1],
-                description=args.set[2])
-            session.add(pool)
-            session.commit()
-            print("{} ".format(pool))
-        elif pool and args.set:
-            pool.slots = args.set[1]
-            pool.description = args.set[2]
-            session.commit()
-            print("{} ".format(pool))
-            return
-        elif pool and args.delete:
-            session.query(Pool).filter_by(pool=args.delete).delete()
-            session.commit()
-            print("Pool {} deleted".format(name))
+    def _tabulate(pools):
+        return "\n%s" % tabulate(pools, ['Pool', 'Slots', 'Description'],
+                                 tablefmt="fancy_grid")
+
+    try:
+        if args.get is not None:
+            pools = [api_client.get_pool(name=args.get)]
+        elif args.set:
+            pools = [api_client.create_pool(name=args.set[0],
+                                            slots=args.set[1],
+                                            description=args.set[2])]
+        elif args.delete:
+            pools = [api_client.delete_pool(name=args.delete)]
+        else:
+            pools = api_client.get_pools()
+    except (AirflowException, IOError) as err:
+        logging.error(err)
+    else:
+        logging.info(_tabulate(pools=pools))
 
 
 def variables(args):
 
     if args.get:
         try:
             var = Variable.get(args.get,
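With this change the airflow pool subcommand talks to the configured api_client instead of opening its own Session, and results are rendered with tabulate. The flags exercised by the tests below are -s name slots description (create/update), -g name (get), -x name (delete), and no flag to list all pools. A sketch driving the same code path the tests use (the "from airflow.bin import cli" import is an assumption about how the test suite loads the module):

from airflow.bin import cli  # assumed import path

parser = cli.CLIFactory.get_parser()
cli.pool(parser.parse_args(['pool', '-s', 'backfill', '4', 'backfill pool']))  # create or update
cli.pool(parser.parse_args(['pool', '-g', 'backfill']))                        # get one pool
cli.pool(parser.parse_args(['pool']))                                          # list all pools
cli.pool(parser.parse_args(['pool', '-x', 'backfill']))                        # delete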
@@ -4395,6 +4395,14 @@ class Pool(Base):
     def __repr__(self):
         return self.pool
 
+    def to_json(self):
+        return {
+            'id': self.id,
+            'pool': self.pool,
+            'slots': self.slots,
+            'description': self.description,
+        }
+
     @provide_session
     def used_slots(self, session):
         """
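Pool.to_json() returns a plain dict, which the new REST endpoints below hand to Flask's jsonify. For one pool row the payload looks roughly like this (values are illustrative):

# Roughly what Pool.to_json() yields for a single row.
{
    'id': 1,
    'pool': 'backfill',
    'slots': 4,
    'description': 'backfill pool',
}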
@@ -11,10 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
 import logging
 
 import airflow.api
 
+from airflow.api.common.experimental import pool as pool_api
 from airflow.api.common.experimental import trigger_dag as trigger
 from airflow.api.common.experimental.get_task import get_task
 from airflow.api.common.experimental.get_task_instance import get_task_instance
@@ -96,7 +98,6 @@ def test():
 @requires_authentication
 def task_info(dag_id, task_id):
     """Returns a JSON with a task's public instance variables. """
-
     try:
         info = get_task(dag_id, task_id)
     except AirflowException as err:
@@ -169,4 +170,67 @@ def latest_dag_runs():
             'dag_run_url': url_for('airflow.graph', dag_id=dagrun.dag_id,
                                    execution_date=dagrun.execution_date)
         })
     return jsonify(items=payload)  # old flask versions dont support jsonifying arrays
+
+
+@api_experimental.route('/pools/<string:name>', methods=['GET'])
+@requires_authentication
+def get_pool(name):
+    """Get pool by a given name."""
+    try:
+        pool = pool_api.get_pool(name=name)
+    except AirflowException as e:
+        _log.error(e)
+        response = jsonify(error="{}".format(e))
+        response.status_code = getattr(e, 'status', 500)
+        return response
+    else:
+        return jsonify(pool.to_json())
+
+
+@api_experimental.route('/pools', methods=['GET'])
+@requires_authentication
+def get_pools():
+    """Get all pools."""
+    try:
+        pools = pool_api.get_pools()
+    except AirflowException as e:
+        _log.error(e)
+        response = jsonify(error="{}".format(e))
+        response.status_code = getattr(e, 'status', 500)
+        return response
+    else:
+        return jsonify([p.to_json() for p in pools])
+
+
+@csrf.exempt
+@api_experimental.route('/pools', methods=['POST'])
+@requires_authentication
+def create_pool():
+    """Create a pool."""
+    params = request.get_json(force=True)
+    try:
+        pool = pool_api.create_pool(**params)
+    except AirflowException as e:
+        _log.error(e)
+        response = jsonify(error="{}".format(e))
+        response.status_code = getattr(e, 'status', 500)
+        return response
+    else:
+        return jsonify(pool.to_json())
+
+
+@csrf.exempt
+@api_experimental.route('/pools/<string:name>', methods=['DELETE'])
+@requires_authentication
+def delete_pool(name):
+    """Delete pool."""
+    try:
+        pool = pool_api.delete_pool(name=name)
+    except AirflowException as e:
+        _log.error(e)
+        response = jsonify(error="{}".format(e))
+        response.status_code = getattr(e, 'status', 500)
+        return response
+    else:
+        return jsonify(pool.to_json())
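These four routes expose the experimental pool API over HTTP; failures come back as a JSON body with an "error" key and the exception's status attribute (400 for bad requests, 404 for missing pools, 500 otherwise). A sketch of exercising them with requests (the webserver address and the absence of auth are assumptions about the deployment):

import requests

base = 'http://localhost:8080/api/experimental'  # assumed webserver address

# Create (or update) a pool.
requests.post(base + '/pools',
              json={'name': 'backfill', 'slots': 4, 'description': 'backfill pool'})

# Fetch one pool and the full list.
print(requests.get(base + '/pools/backfill').json())
print(requests.get(base + '/pools').json())

# Delete it; deleting again returns 404 with an error message.
print(requests.delete(base + '/pools/backfill').status_code)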
@@ -11,9 +11,3 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-from __future__ import absolute_import
-
-from .client import *
-from .common import *
-
@@ -12,16 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import datetime
 import json
 import unittest
-import datetime
 
 from mock import patch
 
 from airflow import AirflowException
-from airflow import models
-
 from airflow.api.client.local_client import Client
+from airflow import models
+from airflow import settings
 from airflow.utils.state import State
 
 EXECDATE = datetime.datetime.now()
@@ -53,8 +53,25 @@ def mock_datetime_now(target, dt):
 
 
 class TestLocalClient(unittest.TestCase):
+
+    @classmethod
+    def setUpClass(cls):
+        super(TestLocalClient, cls).setUpClass()
+        session = settings.Session()
+        session.query(models.Pool).delete()
+        session.commit()
+        session.close()
+
     def setUp(self):
+        super(TestLocalClient, self).setUp()
         self.client = Client(api_base_url=None, auth=None)
+        self.session = settings.Session()
+
+    def tearDown(self):
+        self.session.query(models.Pool).delete()
+        self.session.commit()
+        self.session.close()
+        super(TestLocalClient, self).tearDown()
 
     @patch.object(models.DAG, 'create_dagrun')
     def test_trigger_dag(self, mock):
@@ -104,4 +121,24 @@ class TestLocalClient(unittest.TestCase):
                                     external_trigger=True)
             mock.reset_mock()
 
-        # this is a unit test only, cannot verify existing dag run
+    def test_get_pool(self):
+        self.client.create_pool(name='foo', slots=1, description='')
+        pool = self.client.get_pool(name='foo')
+        self.assertEqual(pool, ('foo', 1, ''))
+
+    def test_get_pools(self):
+        self.client.create_pool(name='foo1', slots=1, description='')
+        self.client.create_pool(name='foo2', slots=2, description='')
+        pools = sorted(self.client.get_pools(), key=lambda p: p[0])
+        self.assertEqual(pools, [('foo1', 1, ''), ('foo2', 2, '')])
+
+    def test_create_pool(self):
+        pool = self.client.create_pool(name='foo', slots=1, description='')
+        self.assertEqual(pool, ('foo', 1, ''))
+        self.assertEqual(self.session.query(models.Pool).count(), 1)
+
+    def test_delete_pool(self):
+        self.client.create_pool(name='foo', slots=1, description='')
+        self.assertEqual(self.session.query(models.Pool).count(), 1)
+        self.client.delete_pool(name='foo')
+        self.assertEqual(self.session.query(models.Pool).count(), 0)
@@ -0,0 +1,13 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-#
 
 import unittest
 
@@ -27,6 +26,7 @@ DEV_NULL = "/dev/null"
 
 
 class TestMarkTasks(unittest.TestCase):
+
     def setUp(self):
         self.dagbag = models.DagBag(include_examples=True)
         self.dag1 = self.dagbag.dags['test_example_bash_operator']
@@ -52,6 +52,16 @@ class TestMarkTasks(unittest.TestCase):
 
         self.session = Session()
 
+    def tearDown(self):
+        self.dag1.clear()
+        self.dag2.clear()
+
+        # just to make sure we are fully cleaned up
+        self.session.query(models.DagRun).delete()
+        self.session.query(models.TaskInstance).delete()
+        self.session.commit()
+        self.session.close()
+
     def snapshot_state(self, dag, execution_dates):
         TI = models.TaskInstance
         tis = self.session.query(TI).filter(
@@ -197,16 +207,6 @@ class TestMarkTasks(unittest.TestCase):
         self.verify_state(self.dag2, task_ids, [self.execution_dates[0]],
                           State.SUCCESS, [])
 
-    def tearDown(self):
-        self.dag1.clear()
-        self.dag2.clear()
-
-        # just to make sure we are fully cleaned up
-        self.session.query(models.DagRun).delete()
-        self.session.query(models.TaskInstance).delete()
-        self.session.commit()
-        self.session.close()
-
 
 class TestMarkDAGRun(unittest.TestCase):
     def setUp(self):
@@ -0,0 +1,132 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from airflow.api.common.experimental import pool as pool_api
+from airflow import models
+from airflow import settings
+
+
+class TestPool(unittest.TestCase):
+
+    def setUp(self):
+        super(TestPool, self).setUp()
+        self.session = settings.Session()
+        self.pools = []
+        for i in range(2):
+            name = 'experimental_%s' % (i + 1)
+            pool = models.Pool(
+                pool=name,
+                slots=i,
+                description=name,
+            )
+            self.session.add(pool)
+            self.pools.append(pool)
+        self.session.commit()
+
+    def tearDown(self):
+        self.session.query(models.Pool).delete()
+        self.session.commit()
+        self.session.close()
+        super(TestPool, self).tearDown()
+
+    def test_get_pool(self):
+        pool = pool_api.get_pool(name=self.pools[0].pool, session=self.session)
+        self.assertEqual(pool.pool, self.pools[0].pool)
+
+    def test_get_pool_non_existing(self):
+        self.assertRaisesRegexp(pool_api.PoolNotFound,
+                                "^Pool 'test' doesn't exist$",
+                                pool_api.get_pool,
+                                name='test',
+                                session=self.session)
+
+    def test_get_pool_bad_name(self):
+        for name in ('', ' '):
+            self.assertRaisesRegexp(pool_api.PoolBadRequest,
+                                    "^Pool name shouldn't be empty$",
+                                    pool_api.get_pool,
+                                    name=name,
+                                    session=self.session)
+
+    def test_get_pools(self):
+        pools = sorted(pool_api.get_pools(session=self.session),
+                       key=lambda p: p.pool)
+        self.assertEqual(pools[0].pool, self.pools[0].pool)
+        self.assertEqual(pools[1].pool, self.pools[1].pool)
+
+    def test_create_pool(self):
+        pool = pool_api.create_pool(name='foo',
+                                    slots=5,
+                                    description='',
+                                    session=self.session)
+        self.assertEqual(pool.pool, 'foo')
+        self.assertEqual(pool.slots, 5)
+        self.assertEqual(pool.description, '')
+        self.assertEqual(self.session.query(models.Pool).count(), 3)
+
+    def test_create_pool_existing(self):
+        pool = pool_api.create_pool(name=self.pools[0].pool,
+                                    slots=5,
+                                    description='',
+                                    session=self.session)
+        self.assertEqual(pool.pool, self.pools[0].pool)
+        self.assertEqual(pool.slots, 5)
+        self.assertEqual(pool.description, '')
+        self.assertEqual(self.session.query(models.Pool).count(), 2)
+
+    def test_create_pool_bad_name(self):
+        for name in ('', ' '):
+            self.assertRaisesRegexp(pool_api.PoolBadRequest,
+                                    "^Pool name shouldn't be empty$",
+                                    pool_api.create_pool,
+                                    name=name,
+                                    slots=5,
+                                    description='',
+                                    session=self.session)
+
+    def test_create_pool_bad_slots(self):
+        self.assertRaisesRegexp(pool_api.PoolBadRequest,
+                                "^Bad value for `slots`: foo$",
+                                pool_api.create_pool,
+                                name='foo',
+                                slots='foo',
+                                description='',
+                                session=self.session)
+
+    def test_delete_pool(self):
+        pool = pool_api.delete_pool(name=self.pools[0].pool,
+                                    session=self.session)
+        self.assertEqual(pool.pool, self.pools[0].pool)
+        self.assertEqual(self.session.query(models.Pool).count(), 1)
+
+    def test_delete_pool_non_existing(self):
+        self.assertRaisesRegexp(pool_api.PoolNotFound,
+                                "^Pool 'test' doesn't exist$",
+                                pool_api.delete_pool,
+                                name='test',
+                                session=self.session)
+
+    def test_delete_pool_bad_name(self):
+        for name in ('', ' '):
+            self.assertRaisesRegexp(pool_api.PoolBadRequest,
+                                    "^Pool name shouldn't be empty$",
+                                    pool_api.delete_pool,
+                                    name=name,
+                                    session=self.session)
+
+
+if __name__ == '__main__':
+    unittest.main()
@@ -1062,12 +1062,34 @@ class CoreTest(unittest.TestCase):
 
 
 class CliTests(unittest.TestCase):
 
+    @classmethod
+    def setUpClass(cls):
+        super(CliTests, cls).setUpClass()
+        cls._cleanup()
+
     def setUp(self):
+        super(CliTests, self).setUp()
         configuration.load_test_config()
         app = application.create_app()
         app.config['TESTING'] = True
         self.parser = cli.CLIFactory.get_parser()
         self.dagbag = models.DagBag(dag_folder=DEV_NULL, include_examples=True)
+        self.session = Session()
+
+    def tearDown(self):
+        self._cleanup(session=self.session)
+        super(CliTests, self).tearDown()
+
+    @staticmethod
+    def _cleanup(session=None):
+        if session is None:
+            session = Session()
+
+        session.query(models.Pool).delete()
+        session.query(models.Variable).delete()
+        session.commit()
+        session.close()
+
     def test_cli_list_dags(self):
         args = self.parser.parse_args(['list_dags', '--report'])
@@ -1100,8 +1122,8 @@ class CliTests(unittest.TestCase):
             cli.connections(self.parser.parse_args(['connections', '--list']))
             stdout = mock_stdout.getvalue()
         conns = [[x.strip("'") for x in re.findall("'\w+'", line)[:2]]
                  for ii, line in enumerate(stdout.split('\n'))
                  if ii % 2 == 1]
         conns = [conn for conn in conns if len(conn) > 0]
 
         # Assert that some of the connections are present in the output as
@@ -1365,14 +1387,27 @@ class CliTests(unittest.TestCase):
             '-c', 'NOT JSON'])
         )
 
-    def test_pool(self):
-        # Checks if all subcommands are properly received
-        cli.pool(self.parser.parse_args([
-            'pool', '-s', 'foo', '1', '"my foo pool"']))
-        cli.pool(self.parser.parse_args([
-            'pool', '-g', 'foo']))
-        cli.pool(self.parser.parse_args([
-            'pool', '-x', 'foo']))
+    def test_pool_create(self):
+        cli.pool(self.parser.parse_args(['pool', '-s', 'foo', '1', 'test']))
+        self.assertEqual(self.session.query(models.Pool).count(), 1)
+
+    def test_pool_get(self):
+        cli.pool(self.parser.parse_args(['pool', '-s', 'foo', '1', 'test']))
+        try:
+            cli.pool(self.parser.parse_args(['pool', '-g', 'foo']))
+        except Exception as e:
+            self.fail("The 'pool -g foo' command raised unexpectedly: %s" % e)
+
+    def test_pool_delete(self):
+        cli.pool(self.parser.parse_args(['pool', '-s', 'foo', '1', 'test']))
+        cli.pool(self.parser.parse_args(['pool', '-x', 'foo']))
+        self.assertEqual(self.session.query(models.Pool).count(), 0)
+
+    def test_pool_no_args(self):
+        try:
+            cli.pool(self.parser.parse_args(['pool']))
+        except Exception as e:
+            self.fail("The 'pool' command raised unexpectedly: %s" % e)
 
     def test_variables(self):
         # Checks if all subcommands are properly received
@@ -1426,10 +1461,6 @@ class CliTests(unittest.TestCase):
         self.assertEqual('original', models.Variable.get('bar'))
         self.assertEqual('{"foo": "bar"}', models.Variable.get('foo'))
 
-        session = settings.Session()
-        session.query(Variable).delete()
-        session.commit()
-        session.close()
         os.remove('variables1.json')
         os.remove('variables2.json')
 
@@ -19,23 +19,36 @@ from urllib.parse import quote_plus
 
 from airflow import configuration
 from airflow.api.common.experimental.trigger_dag import trigger_dag
-from airflow.models import DagBag, DagRun, TaskInstance
+from airflow.models import DagBag, DagRun, Pool, TaskInstance
 from airflow.settings import Session
 from airflow.www import app as application
 
 
-class ApiExperimentalTests(unittest.TestCase):
+class TestApiExperimental(unittest.TestCase):
 
-    def setUp(self):
-        configuration.load_test_config()
-        app = application.create_app(testing=True)
-        self.app = app.test_client()
+    @classmethod
+    def setUpClass(cls):
+        super(TestApiExperimental, cls).setUpClass()
         session = Session()
         session.query(DagRun).delete()
         session.query(TaskInstance).delete()
         session.commit()
         session.close()
 
+    def setUp(self):
+        super(TestApiExperimental, self).setUp()
+        configuration.load_test_config()
+        app = application.create_app(testing=True)
+        self.app = app.test_client()
+
+    def tearDown(self):
+        session = Session()
+        session.query(DagRun).delete()
+        session.query(TaskInstance).delete()
+        session.commit()
+        session.close()
+        super(TestApiExperimental, self).tearDown()
+
     def test_task_info(self):
         url_template = '/api/experimental/dags/{}/tasks/{}'
@@ -62,7 +75,7 @@ class ApiExperimentalTests(unittest.TestCase):
         url_template = '/api/experimental/dags/{}/dag_runs'
         response = self.app.post(
             url_template.format('example_bash_operator'),
-            data=json.dumps(dict(run_id='my_run' + datetime.now().isoformat())),
+            data=json.dumps({'run_id': 'my_run' + datetime.now().isoformat()}),
             content_type="application/json"
         )
@@ -70,7 +83,7 @@ class ApiExperimentalTests(unittest.TestCase):
 
         response = self.app.post(
             url_template.format('does_not_exist_dag'),
-            data=json.dumps(dict()),
+            data=json.dumps({}),
             content_type="application/json"
         )
         self.assertEqual(404, response.status_code)
@@ -88,7 +101,7 @@ class ApiExperimentalTests(unittest.TestCase):
         # Test Correct execution
         response = self.app.post(
             url_template.format(dag_id),
-            data=json.dumps(dict(execution_date=datetime_string)),
+            data=json.dumps({'execution_date': datetime_string}),
             content_type="application/json"
         )
         self.assertEqual(200, response.status_code)
@@ -103,7 +116,7 @@ class ApiExperimentalTests(unittest.TestCase):
         # Test error for nonexistent dag
         response = self.app.post(
             url_template.format('does_not_exist_dag'),
-            data=json.dumps(dict(execution_date=execution_date.isoformat())),
+            data=json.dumps({'execution_date': execution_date.isoformat()}),
             content_type="application/json"
         )
         self.assertEqual(404, response.status_code)
@@ -111,7 +124,7 @@ class ApiExperimentalTests(unittest.TestCase):
         # Test error for bad datetime format
         response = self.app.post(
             url_template.format(dag_id),
-            data=json.dumps(dict(execution_date='not_a_datetime')),
+            data=json.dumps({'execution_date': 'not_a_datetime'}),
             content_type="application/json"
         )
         self.assertEqual(400, response.status_code)
@@ -122,7 +135,9 @@ class ApiExperimentalTests(unittest.TestCase):
         task_id = 'also_run_this'
         execution_date = datetime.now().replace(microsecond=0)
         datetime_string = quote_plus(execution_date.isoformat())
-        wrong_datetime_string = quote_plus(datetime(1990, 1, 1, 1, 1, 1).isoformat())
+        wrong_datetime_string = quote_plus(
+            datetime(1990, 1, 1, 1, 1, 1).isoformat()
+        )
 
         # Create DagRun
         trigger_dag(dag_id=dag_id,
@@ -139,7 +154,8 @@ class ApiExperimentalTests(unittest.TestCase):
 
         # Test error for nonexistent dag
         response = self.app.get(
-            url_template.format('does_not_exist_dag', datetime_string, task_id),
+            url_template.format('does_not_exist_dag', datetime_string,
+                                task_id),
         )
         self.assertEqual(404, response.status_code)
         self.assertIn('error', response.data.decode('utf-8'))
@@ -164,3 +180,122 @@ class ApiExperimentalTests(unittest.TestCase):
         )
         self.assertEqual(400, response.status_code)
         self.assertIn('error', response.data.decode('utf-8'))
+
+
+class TestPoolApiExperimental(unittest.TestCase):
+
+    @classmethod
+    def setUpClass(cls):
+        super(TestPoolApiExperimental, cls).setUpClass()
+        session = Session()
+        session.query(Pool).delete()
+        session.commit()
+        session.close()
+
+    def setUp(self):
+        super(TestPoolApiExperimental, self).setUp()
+        configuration.load_test_config()
+        app = application.create_app(testing=True)
+        self.app = app.test_client()
+        self.session = Session()
+        self.pools = []
+        for i in range(2):
+            name = 'experimental_%s' % (i + 1)
+            pool = Pool(
+                pool=name,
+                slots=i,
+                description=name,
+            )
+            self.session.add(pool)
+            self.pools.append(pool)
+        self.session.commit()
+        self.pool = self.pools[0]
+
+    def tearDown(self):
+        self.session.query(Pool).delete()
+        self.session.commit()
+        self.session.close()
+        super(TestPoolApiExperimental, self).tearDown()
+
+    def _get_pool_count(self):
+        response = self.app.get('/api/experimental/pools')
+        self.assertEqual(response.status_code, 200)
+        return len(json.loads(response.data.decode('utf-8')))
+
+    def test_get_pool(self):
+        response = self.app.get(
+            '/api/experimental/pools/{}'.format(self.pool.pool),
+        )
+        self.assertEqual(response.status_code, 200)
+        self.assertEqual(json.loads(response.data.decode('utf-8')),
+                         self.pool.to_json())
+
+    def test_get_pool_non_existing(self):
+        response = self.app.get('/api/experimental/pools/foo')
+        self.assertEqual(response.status_code, 404)
+        self.assertEqual(json.loads(response.data.decode('utf-8'))['error'],
+                         "Pool 'foo' doesn't exist")
+
+    def test_get_pools(self):
+        response = self.app.get('/api/experimental/pools')
+        self.assertEqual(response.status_code, 200)
+        pools = json.loads(response.data.decode('utf-8'))
+        self.assertEqual(len(pools), 2)
+        for i, pool in enumerate(sorted(pools, key=lambda p: p['pool'])):
+            self.assertDictEqual(pool, self.pools[i].to_json())
+
+    def test_create_pool(self):
+        response = self.app.post(
+            '/api/experimental/pools',
+            data=json.dumps({
+                'name': 'foo',
+                'slots': 1,
+                'description': '',
+            }),
+            content_type='application/json',
+        )
+        self.assertEqual(response.status_code, 200)
+        pool = json.loads(response.data.decode('utf-8'))
+        self.assertEqual(pool['pool'], 'foo')
+        self.assertEqual(pool['slots'], 1)
+        self.assertEqual(pool['description'], '')
+        self.assertEqual(self._get_pool_count(), 3)
+
+    def test_create_pool_with_bad_name(self):
+        for name in ('', ' '):
+            response = self.app.post(
+                '/api/experimental/pools',
+                data=json.dumps({
+                    'name': name,
+                    'slots': 1,
+                    'description': '',
+                }),
+                content_type='application/json',
+            )
+            self.assertEqual(response.status_code, 400)
+            self.assertEqual(
+                json.loads(response.data.decode('utf-8'))['error'],
+                "Pool name shouldn't be empty",
+            )
+        self.assertEqual(self._get_pool_count(), 2)
+
+    def test_delete_pool(self):
+        response = self.app.delete(
+            '/api/experimental/pools/{}'.format(self.pool.pool),
+        )
+        self.assertEqual(response.status_code, 200)
+        self.assertEqual(json.loads(response.data.decode('utf-8')),
+                         self.pool.to_json())
+        self.assertEqual(self._get_pool_count(), 1)
+
+    def test_delete_pool_non_existing(self):
+        response = self.app.delete(
+            '/api/experimental/pools/foo',
+        )
+        self.assertEqual(response.status_code, 404)
+        self.assertEqual(json.loads(response.data.decode('utf-8'))['error'],
+                         "Pool 'foo' doesn't exist")
+
+
+if __name__ == '__main__':
+    unittest.main()