Make Dag Serialization a hard requirement (#11335)

Scheduler HA uses Serialized DAGs and hence it is a strict
requirement for 2.0.

It also has performance benefits for the Webserver and so should
be used by default anyway.

Task execution on workers will continue to use the actual files for execution.

The Scheduler, Experimental API and Webserver will read the DAGs from the DB using
`DagBag(read_dags_from_db=True)`.
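
A minimal sketch of that read path (the `dag_id` is illustrative):

    from airflow.models import DagBag

    dagbag = DagBag(read_dags_from_db=True)        # fetch DAGs from the serialized_dag table
    dag = dagbag.get_dag("example_bash_operator")  # deserialized DAG; no python files parsed
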
Kaxil Naik 2020-10-26 18:16:29 +00:00, committed by GitHub
Parent 0d3ee66924
Commit 406ed29252
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
32 changed files with 514 additions and 590 deletions

@ -256,6 +256,14 @@ The Old and New provider configuration keys that have changed are as follows
For more information, visit https://flask-appbuilder.readthedocs.io/en/latest/security.html#authentication-oauth
### DAG Serialization will be strictly required
Until Airflow 2.0, DAG Serialization was disabled by default. From Airflow 2.0, DAG Serialization
is enabled by default and the Scheduler uses Serialized DAGs to make scheduling decisions,
so DAG Serialization cannot be turned off.
The previous setting `[core] store_serialized_dags` will be ignored.
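
As a rough illustration (option names are the ones documented elsewhere in this commit; the `fallback` values mirror the documented defaults), the knobs that remain meaningful are read like this:

    from airflow.configuration import conf

    # [core] store_serialized_dags is ignored from 2.0 onwards; these still apply:
    update_interval = conf.getint("core", "min_serialized_dag_update_interval", fallback=30)
    fetch_interval = conf.getint("core", "min_serialized_dag_fetch_interval", fallback=10)
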
### Changes to the KubernetesExecutor
#### The KubernetesExecutor Will No Longer Read from the airflow.cfg for Base Pod Configurations

@ -19,7 +19,6 @@
from datetime import datetime
from typing import Optional
from airflow.configuration import conf
from airflow.exceptions import DagNotFound, DagRunNotFound, TaskNotFound
from airflow.models import DagBag, DagModel, DagRun
@ -32,10 +31,10 @@ def check_and_get_dag(dag_id: str, task_id: Optional[str] = None) -> DagModel:
dagbag = DagBag(
dag_folder=dag_model.fileloc,
read_dags_from_db=conf.getboolean('core', 'store_serialized_dags')
read_dags_from_db=True
)
dag = dagbag.get_dag(dag_id) # prefetch dag if it is stored serialized
if dag_id not in dagbag.dags:
dag = dagbag.get_dag(dag_id)
if not dag:
error_message = "Dag id {} not found".format(dag_id)
raise DagNotFound(error_message)
if task_id and not dag.has_task(task_id):
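
A hypothetical caller of this helper after the change (the import path is an assumption based on the experimental API package; ids are illustrative):

    from airflow.api.common.experimental import check_and_get_dag
    from airflow.exceptions import DagNotFound

    try:
        dag_model = check_and_get_dag(dag_id="example_bash_operator", task_id="runme_0")
    except DagNotFound:
        pass  # surfaced as a 404 by the HTTP layer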

@ -24,7 +24,6 @@ from airflow import models
from airflow.exceptions import DagNotFound
from airflow.models import DagModel, TaskFail
from airflow.models.serialized_dag import SerializedDagModel
from airflow.settings import STORE_SERIALIZED_DAGS
from airflow.utils.session import provide_session
log = logging.getLogger(__name__)
@ -47,7 +46,7 @@ def delete_dag(dag_id: str, keep_records_in_log: bool = True, session=None) -> i
# Scheduler removes DAGs without files from serialized_dag table every dag_dir_list_interval.
# There may be a lag, so explicitly removes serialized DAG here.
if STORE_SERIALIZED_DAGS and SerializedDagModel.has_dag(dag_id=dag_id, session=session):
if SerializedDagModel.has_dag(dag_id=dag_id, session=session):
SerializedDagModel.remove_dag(dag_id=dag_id, session=session)
count = 0
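
Illustrative usage (import path assumed from the experimental API layout); with serialization always on, the serialized copy is now purged unconditionally:

    from airflow.api.common.experimental.delete_dag import delete_dag

    removed_rows = delete_dag("obsolete_dag_id", keep_records_in_log=True)  # rows removed across DAG tables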

@ -61,15 +61,16 @@ def _create_dagruns(dag, execution_dates, state, run_type):
@provide_session
def set_state(
tasks: Iterable[BaseOperator],
execution_date: datetime.datetime,
upstream: bool = False,
downstream: bool = False,
future: bool = False,
past: bool = False,
state: str = State.SUCCESS,
commit: bool = False,
session=None): # pylint: disable=too-many-arguments,too-many-locals
tasks: Iterable[BaseOperator],
execution_date: datetime.datetime,
upstream: bool = False,
downstream: bool = False,
future: bool = False,
past: bool = False,
state: str = State.SUCCESS,
commit: bool = False,
session=None
): # pylint: disable=too-many-arguments,too-many-locals
"""
Set the state of a task instance and if needed its relatives. Can set state
for future tasks (calculated from execution_date) and retroactively

@ -114,13 +114,7 @@ def trigger_dag(
if dag_model is None:
raise DagNotFound("Dag id {} not found in DagModel".format(dag_id))
def read_store_serialized_dags():
from airflow.configuration import conf
return conf.getboolean('core', 'store_serialized_dags')
dagbag = DagBag(
dag_folder=dag_model.fileloc,
read_dags_from_db=read_store_serialized_dags()
)
dagbag = DagBag(dag_folder=dag_model.fileloc, read_dags_from_db=True)
triggers = _trigger_dag(
dag_id=dag_id,
dag_bag=dagbag,

@ -34,6 +34,7 @@ from airflow.api_connexion.schemas.task_instance_schema import (
TaskInstanceReferenceCollection,
set_task_instance_state_form,
)
from airflow.exceptions import SerializedDagNotFound
from airflow.models.dagrun import DagRun as DR
from airflow.models.taskinstance import clear_task_instances, TaskInstance as TI
from airflow.models import SlaMiss
@ -286,9 +287,13 @@ def post_set_task_instances_state(dag_id, session):
except ValidationError as err:
raise BadRequest(detail=str(err.messages))
dag = current_app.dag_bag.get_dag(dag_id)
if not dag:
error_message = "Dag ID {} not found".format(dag_id)
error_message = "Dag ID {} not found".format(dag_id)
try:
dag = current_app.dag_bag.get_dag(dag_id)
if not dag:
raise NotFound(error_message)
except SerializedDagNotFound:
# If DAG is not found in serialized_dag table
raise NotFound(error_message)
task_id = data['task_id']

@ -16,7 +16,6 @@
# specific language governing permissions and limitations
# under the License.
"""Sync permission command"""
from airflow import settings
from airflow.models import DagBag
from airflow.utils import cli as cli_utils
from airflow.www.app import cached_app
@ -29,7 +28,7 @@ def sync_perm(args):
print('Updating permission, view-menu for all existing roles')
appbuilder.sm.sync_roles()
print('Updating permission on all DAG views')
dags = DagBag(store_serialized_dags=settings.STORE_SERIALIZED_DAGS).dags.values()
dags = DagBag(read_dags_from_db=True).dags.values()
for dag in dags:
appbuilder.sm.sync_perm_for_dag(
dag.dag_id,

@ -126,6 +126,10 @@ class DuplicateTaskIdFound(AirflowException):
"""Raise when a Task with duplicate task_id is defined in the same DAG"""
class SerializedDagNotFound(DagNotFound):
"""Raise when DAG is not found in the serialized_dags table in DB"""
class TaskNotFound(AirflowNotFoundException):
"""Raise when a Task is not available in the system"""

@ -36,7 +36,7 @@ from tabulate import tabulate
from airflow import settings
from airflow.configuration import conf
from airflow.dag.base_dag import BaseDagBag
from airflow.exceptions import AirflowClusterPolicyViolation, AirflowDagCycleException
from airflow.exceptions import AirflowClusterPolicyViolation, AirflowDagCycleException, SerializedDagNotFound
from airflow.plugins_manager import integrate_dag_plugins
from airflow.stats import Stats
from airflow.utils import timezone
@ -75,10 +75,8 @@ class DagBag(BaseDagBag, LoggingMixin):
:param include_smart_sensor: whether to include the smart sensor native
DAGs that create the smart sensor operators for whole cluster
:type include_smart_sensor: bool
:param read_dags_from_db: Read DAGs from DB if store_serialized_dags is ``True``.
If ``False`` DAGs are read from python files. This property is not used when
determining whether or not to write Serialized DAGs, that is done by checking
the config ``store_serialized_dags``.
:param read_dags_from_db: Read DAGs from DB if ``True`` is passed.
If ``False``, DAGs are read from Python files.
:type read_dags_from_db: bool
"""
@ -214,7 +212,7 @@ class DagBag(BaseDagBag, LoggingMixin):
from airflow.models.serialized_dag import SerializedDagModel
row = SerializedDagModel.get(dag_id, session)
if not row:
raise ValueError(f"DAG '{dag_id}' not found in serialized_dag table")
raise SerializedDagNotFound(f"DAG '{dag_id}' not found in serialized_dag table")
dag = row.dag
for subdag in dag.subdags:
@ -527,8 +525,6 @@ class DagBag(BaseDagBag, LoggingMixin):
from airflow.models.serialized_dag import SerializedDagModel
self.log.debug("Calling the DAG.bulk_sync_to_db method")
DAG.bulk_write_to_db(self.dags.values(), session=session)
# Write Serialized DAGs to DB if DAG Serialization is turned on
# Even though self.read_dags_from_db is False
if settings.STORE_SERIALIZED_DAGS or self.read_dags_from_db:
self.log.debug("Calling the SerializedDagModel.bulk_sync_to_db method")
SerializedDagModel.bulk_sync_to_db(self.dags.values(), session=session)
# Write Serialized DAGs to DB
self.log.debug("Calling the SerializedDagModel.bulk_sync_to_db method")
SerializedDagModel.bulk_sync_to_db(self.dags.values(), session=session)
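
A sketch of the resulting write path (folder path illustrative): every `sync_to_db()` now also serializes, with no setting left to check.

    from airflow.models import DagBag

    dagbag = DagBag(dag_folder="/path/to/dags", read_dags_from_db=False)  # parse python files
    dagbag.sync_to_db()  # writes DagModel rows and, unconditionally, SerializedDagModel rows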

@ -39,7 +39,6 @@ class DagCode(Base):
dag_code table contains code of DAG files synchronized by scheduler.
This feature is controlled by:
* ``[core] store_serialized_dags = True``: enable this feature
* ``[core] store_dag_code = True``: enable this feature
For details on dag serialization see SerializedDagModel

@ -47,7 +47,6 @@ class SerializedDagModel(Base):
serialized_dag table is a snapshot of DAG files synchronized by scheduler.
This feature is controlled by:
* ``[core] store_serialized_dags = True``: enable this feature
* ``[core] min_serialized_dag_update_interval = 30`` (s):
serialized DAGs are updated in DB when a file gets processed by scheduler,
to reduce DB write rate, there is a minimal interval of updating serialized DAGs.
@ -55,8 +54,8 @@ class SerializedDagModel(Base):
interval of deleting serialized DAGs in DB when the files are deleted, suggest
to use a smaller interval such as 60
It is used by webserver to load dags when ``store_serialized_dags=True``.
Because reading from database is lightweight compared to importing from files,
It is used by the webserver to load DAGs,
because reading from the database is lightweight compared to importing from files;
this solves the webserver scalability issue.
"""

@ -51,7 +51,6 @@ from airflow.models.taskreschedule import TaskReschedule
from airflow.models.variable import Variable
from airflow.models.xcom import XCOM_RETURN_KEY, XCom
from airflow.sentry import Sentry
from airflow.settings import STORE_SERIALIZED_DAGS
from airflow.stats import Stats
from airflow.ti_deps.dep_context import DepContext
from airflow.ti_deps.dependencies_deps import REQUEUEABLE_DEPS, RUNNING_DEPS
@ -1160,9 +1159,8 @@ class TaskInstance(Base, LoggingMixin): # pylint: disable=R0902,R0904
start_time = time.time()
self.render_templates(context=context)
if STORE_SERIALIZED_DAGS:
RenderedTaskInstanceFields.write(RenderedTaskInstanceFields(ti=self, render_templates=False))
RenderedTaskInstanceFields.delete_old_records(self.task_id, self.dag_id)
RenderedTaskInstanceFields.write(RenderedTaskInstanceFields(ti=self, render_templates=False))
RenderedTaskInstanceFields.delete_old_records(self.task_id, self.dag_id)
# Export context to make it available for operators to use.
airflow_context_vars = context_to_airflow_vars(context, in_env_var_format=True)
@ -1576,28 +1574,22 @@ class TaskInstance(Base, LoggingMixin): # pylint: disable=R0902,R0904
}
def get_rendered_template_fields(self):
"""
Fetch rendered template fields from DB if Serialization is enabled.
Else just render the templates
"""
"""Fetch rendered template fields from DB"""
from airflow.models.renderedtifields import RenderedTaskInstanceFields
if STORE_SERIALIZED_DAGS:
rendered_task_instance_fields = RenderedTaskInstanceFields.get_templated_fields(self)
if rendered_task_instance_fields:
for field_name, rendered_value in rendered_task_instance_fields.items():
setattr(self.task, field_name, rendered_value)
else:
try:
self.render_templates()
except (TemplateAssertionError, UndefinedError) as e:
raise AirflowException(
"Webserver does not have access to User-defined Macros or Filters "
"when Dag Serialization is enabled. Hence for the task that have not yet "
"started running, please use 'airflow tasks render' for debugging the "
"rendering of template_fields."
) from e
rendered_task_instance_fields = RenderedTaskInstanceFields.get_templated_fields(self)
if rendered_task_instance_fields:
for field_name, rendered_value in rendered_task_instance_fields.items():
setattr(self.task, field_name, rendered_value)
else:
self.render_templates()
try:
self.render_templates()
except (TemplateAssertionError, UndefinedError) as e:
raise AirflowException(
"Webserver does not have access to User-defined Macros or Filters "
"when Dag Serialization is enabled. Hence for the task that have not yet "
"started running, please use 'airflow tasks render' for debugging the "
"rendering of template_fields."
) from e
def overwrite_params_with_dag_run_conf(self, params, dag_run):
"""Overwrite Task Params with DagRun.conf"""

@ -20,7 +20,6 @@ import datetime
from typing import Dict, Optional, Union
from urllib.parse import quote
from airflow import settings
from airflow.api.common.experimental.trigger_dag import trigger_dag
from airflow.exceptions import DagNotFound, DagRunAlreadyExists
from airflow.models import BaseOperator, BaseOperatorLink, DagBag, DagModel, DagRun
@ -122,7 +121,7 @@ class TriggerDagRunOperator(BaseOperator):
dag_bag = DagBag(
dag_folder=dag_model.fileloc,
store_serialized_dags=settings.STORE_SERIALIZED_DAGS
read_dags_from_db=True
)
dag = dag_bag.get_dag(self.trigger_dag_id)
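
An illustrative snippet using the operator, which now resolves `trigger_dag_id` through the serialized store at run time (ids are placeholders):

    from airflow.operators.dagrun_operator import TriggerDagRunOperator

    trigger = TriggerDagRunOperator(
        task_id="trigger_downstream",
        trigger_dag_id="downstream_dag",  # must already be serialized in the DB
    )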

@ -344,9 +344,6 @@ MEGABYTE = KILOBYTE * KILOBYTE
WEB_COLORS = {'LIGHTBLUE': '#4d9de0',
'LIGHTORANGE': '#FF9933'}
# If store_serialized_dags is True, scheduler writes serialized DAGs to DB, and webserver
# reads DAGs from DB instead of importing from files.
STORE_SERIALIZED_DAGS = conf.getboolean('core', 'store_serialized_dags', fallback=False)
# Updating serialized DAG can not be faster than a minimum interval to reduce database
# write rate.
@ -360,8 +357,7 @@ MIN_SERIALIZED_DAG_FETCH_INTERVAL = conf.getint(
# Whether to persist DAG files code in DB. If set to True, Webserver reads file contents
# from DB instead of trying to access files in a DAG folder.
# Defaults to same as the store_serialized_dags setting.
STORE_DAG_CODE = conf.getboolean("core", "store_dag_code", fallback=STORE_SERIALIZED_DAGS)
STORE_DAG_CODE = conf.getboolean("core", "store_dag_code", fallback=True)
# If donot_modify_handlers=True, we do not modify logging handlers in task_run command
# If the flag is set to False, we remove all handlers from the root logger

@ -38,9 +38,10 @@ from tabulate import tabulate
import airflow.models
from airflow.configuration import conf
from airflow.models import errors
from airflow.models import DagModel, errors
from airflow.models.serialized_dag import SerializedDagModel
from airflow.models.taskinstance import SimpleTaskInstance
from airflow.settings import STORE_DAG_CODE, STORE_SERIALIZED_DAGS
from airflow.settings import STORE_DAG_CODE
from airflow.stats import Stats
from airflow.utils import timezone
from airflow.utils.callback_requests import CallbackRequest, SlaCallbackRequest, TaskCallbackRequest
@ -734,11 +735,8 @@ class DagFileProcessorManager(LoggingMixin): # pylint: disable=too-many-instanc
except Exception: # noqa pylint: disable=broad-except
self.log.exception("Error removing old import errors")
if STORE_SERIALIZED_DAGS:
from airflow.models.dag import DagModel
from airflow.models.serialized_dag import SerializedDagModel
SerializedDagModel.remove_deleted_dags(self._file_paths)
DagModel.deactivate_deleted_dags(self._file_paths)
SerializedDagModel.remove_deleted_dags(self._file_paths)
DagModel.deactivate_deleted_dags(self._file_paths)
if self.store_dag_code:
from airflow.models.dagcode import DagCode

@ -18,7 +18,7 @@
import os
from airflow.models import DagBag
from airflow.settings import DAGS_FOLDER, STORE_SERIALIZED_DAGS
from airflow.settings import DAGS_FOLDER
def init_dagbag(app):
@ -29,4 +29,4 @@ def init_dagbag(app):
if os.environ.get('SKIP_DAGS_PARSING') == 'True':
app.dag_bag = DagBag(os.devnull, include_examples=False)
else:
app.dag_bag = DagBag(DAGS_FOLDER, read_dags_from_db=STORE_SERIALIZED_DAGS)
app.dag_bag = DagBag(DAGS_FOLDER, read_dags_from_db=True)

@ -1867,7 +1867,12 @@ class Airflow(AirflowBaseView): # noqa: D101 pylint: disable=too-many-public-m
"""Get Dag as duration graph."""
default_dag_run = conf.getint('webserver', 'default_dag_run_display_number')
dag_id = request.args.get('dag_id')
dag = current_app.dag_bag.get_dag(dag_id)
try:
dag = current_app.dag_bag.get_dag(dag_id)
except airflow.exceptions.SerializedDagNotFound:
dag = None
base_date = request.args.get('base_date')
num_runs = request.args.get('num_runs')
num_runs = int(num_runs) if num_runs else default_dag_run
@ -2159,10 +2164,7 @@ class Airflow(AirflowBaseView): # noqa: D101 pylint: disable=too-many-public-m
@action_logging
def refresh_all(self):
"""Refresh everything"""
if settings.STORE_SERIALIZED_DAGS:
current_app.dag_bag.collect_dags_from_db()
else:
current_app.dag_bag.collect_dags(only_if_updated=False)
current_app.dag_bag.collect_dags_from_db()
# sync permissions for all dags
for dag_id, dag in current_app.dag_bag.dags.items():

@ -55,15 +55,18 @@ The data is stored in the :class:`airflow.models.renderedtifields.RenderedTaskIn
To limit the excessive growth of the database, only the most recent entries are kept and older entries
are purged.
Enable Dag Serialization
------------------------
.. note::
From Airflow 2.0, DAG Serialization is strictly required and cannot be turned off.
Dag Serialization Settings
---------------------------
Add the following settings in ``airflow.cfg``:
.. code-block:: ini
[core]
store_serialized_dags = True
store_dag_code = True
# You can also update the following default configurations based on your needs
@ -71,8 +74,6 @@ Add the following settings in ``airflow.cfg``:
min_serialized_dag_fetch_interval = 10
max_num_rendered_ti_fields_per_task = 30
* ``store_serialized_dags``: This option decides whether to serialise DAGs and persist them in DB.
If set to True, Webserver reads from DB instead of parsing DAG files
* ``store_dag_code``: This option decides whether to persist DAG files code in DB.
If set to True, Webserver reads file contents from DB instead of trying to access files in a DAG folder.
* ``min_serialized_dag_update_interval``: This flag sets the minimum interval (in seconds) after which

@ -42,7 +42,9 @@ class TestMarkTasks(unittest.TestCase):
@classmethod
def setUpClass(cls):
dagbag = models.DagBag(include_examples=True)
models.DagBag(include_examples=True, read_dags_from_db=False).sync_to_db()
dagbag = models.DagBag(include_examples=True, read_dags_from_db=True)
dagbag.collect_dags_from_db()
cls.dag1 = dagbag.dags['example_bash_operator']
cls.dag1.sync_to_db()
cls.dag2 = dagbag.dags['example_subdag_operator']
@ -281,7 +283,7 @@ class TestMarkDAGRun(unittest.TestCase):
@classmethod
def setUpClass(cls):
dagbag = models.DagBag(include_examples=True)
dagbag = models.DagBag(include_examples=True, read_dags_from_db=False)
cls.dag1 = dagbag.dags['example_bash_operator']
cls.dag1.sync_to_db()
cls.dag2 = dagbag.dags['example_subdag_operator']

@ -112,7 +112,6 @@ class TestGetTask(TestTaskEndpoint):
assert response.status_code == 200
assert response.json == expected
@conf_vars({("core", "store_serialized_dags"): "True"})
def test_should_response_200_serialized(self):
# Create empty app with empty dagbag to check if DAG is read from db
with conf_vars({("api", "auth_backend"): "tests.test_utils.remote_user_api_auth_backend"}):

@ -79,7 +79,8 @@ class TestTaskInstanceEndpoint(unittest.TestCase):
self.client = self.app.test_client() # type:ignore
clear_db_runs()
clear_db_sla_miss()
self.dagbag = DagBag(include_examples=True)
DagBag(include_examples=True, read_dags_from_db=False).sync_to_db()
self.dagbag = DagBag(include_examples=True, read_dags_from_db=True)
def create_task_instances(
self,
@ -92,7 +93,7 @@ class TestTaskInstanceEndpoint(unittest.TestCase):
):
"""Method to create task instances using kwargs and default arguments"""
dag = self.dagbag.dags[dag_id]
dag = self.dagbag.get_dag(dag_id)
tasks = dag.tasks
counter = len(tasks)
if task_instances is not None:

@ -34,7 +34,6 @@ class TestCliSyncPerm(unittest.TestCase):
@mock.patch("airflow.cli.commands.sync_perm_command.cached_app")
@mock.patch("airflow.cli.commands.sync_perm_command.DagBag")
@mock.patch("airflow.settings.STORE_SERIALIZED_DAGS", True)
def test_cli_sync_perm(self, dagbag_mock, mock_cached_app):
self.expect_dagbag_contains([
DAG('has_access_control',
@ -53,7 +52,7 @@ class TestCliSyncPerm(unittest.TestCase):
assert appbuilder.sm.sync_roles.call_count == 1
dagbag_mock.assert_called_once_with(store_serialized_dags=True)
dagbag_mock.assert_called_once_with(read_dags_from_db=True)
self.assertEqual(2, len(appbuilder.sm.sync_perm_for_dag.mock_calls))
appbuilder.sm.sync_perm_for_dag.assert_any_call(
'has_access_control',

@ -67,8 +67,7 @@ class TestCore(unittest.TestCase):
default_scheduler_args = {"num_runs": 1}
def setUp(self):
self.dagbag = DagBag(
dag_folder=DEV_NULL, include_examples=True)
self.dagbag = DagBag(dag_folder=DEV_NULL, include_examples=True, read_dags_from_db=False)
self.args = {'owner': 'airflow', 'start_date': DEFAULT_DATE}
self.dag = DAG(TEST_DAG_ID, default_args=self.args)
self.dag_bash = self.dagbag.dags['example_bash_operator']

@ -134,7 +134,6 @@ class TestDagFileProcessor(unittest.TestCase):
return dag
@classmethod
@patch("airflow.models.dagbag.settings.STORE_SERIALIZED_DAGS", True)
def setUpClass(cls):
# Ensure the DAGs we are looking at from the DB are up-to-date
non_serialized_dagbag = DagBag(read_dags_from_db=False, include_examples=False)
@ -548,7 +547,10 @@ class TestDagFileProcessor(unittest.TestCase):
scheduler = SchedulerJob()
scheduler.dagbag.bag_dag(dag, root_dag=dag)
scheduler.dagbag.sync_to_db()
# Since we don't want to store the code for the DAG defined in this file
with mock.patch.object(settings, "STORE_DAG_CODE", False):
scheduler.dagbag.sync_to_db()
session = settings.Session()
orm_dag = session.query(DagModel).get(dag.dag_id)
@ -649,7 +651,7 @@ class TestDagFileProcessor(unittest.TestCase):
@patch.object(TaskInstance, 'handle_failure')
def test_execute_on_failure_callbacks(self, mock_ti_handle_failure):
dagbag = DagBag(dag_folder="/dev/null", include_examples=True)
dagbag = DagBag(dag_folder="/dev/null", include_examples=True, read_dags_from_db=False)
dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock())
with create_session() as session:
session.query(TaskInstance).delete()
@ -730,9 +732,8 @@ class TestDagFileProcessor(unittest.TestCase):
)
# Write DAGs to dag and serialized_dag table
with mock.patch("airflow.models.dagbag.settings.STORE_SERIALIZED_DAGS", return_value=True):
dagbag = DagBag(dag_folder=dag_file, include_examples=False)
dagbag.sync_to_db()
dagbag = DagBag(dag_folder=dag_file, include_examples=False, read_dags_from_db=False)
dagbag.sync_to_db()
scheduler_job = SchedulerJob()
scheduler_job.processor_agent = mock.MagicMock()
@ -811,8 +812,17 @@ class TestSchedulerJob(unittest.TestCase):
# enqueue!
self.null_exec = MockExecutor()
self.patcher = patch('airflow.utils.dag_processing.SerializedDagModel.remove_deleted_dags')
# Since we don't want to store the code for the DAG defined in this file
self.patcher_dag_code = patch.object(settings, "STORE_DAG_CODE", False)
self.patcher.start()
self.patcher_dag_code.start()
def tearDown(self):
self.patcher.stop()
self.patcher_dag_code.stop()
@classmethod
@patch("airflow.models.dagbag.settings.STORE_SERIALIZED_DAGS", True)
def setUpClass(cls):
# Ensure the DAGs we are looking at from the DB are up-to-date
non_serialized_dagbag = DagBag(read_dags_from_db=False, include_examples=False)
@ -1919,10 +1929,9 @@ class TestSchedulerJob(unittest.TestCase):
op1 = DummyOperator(task_id='op1')
# Write Dag to DB
with mock.patch.object(settings, "STORE_SERIALIZED_DAGS", True):
dagbag = DagBag(dag_folder="/dev/null", include_examples=False)
dagbag.bag_dag(dag, root_dag=dag)
dagbag.sync_to_db()
dagbag = DagBag(dag_folder="/dev/null", include_examples=False, read_dags_from_db=False)
dagbag.bag_dag(dag, root_dag=dag)
dagbag.sync_to_db()
dag = DagBag(read_dags_from_db=True, include_examples=False).get_dag(dag_id)
# Create DAG run with FAILED state
@ -1978,6 +1987,7 @@ class TestSchedulerJob(unittest.TestCase):
scheduler = SchedulerJob()
scheduler.dagbag.bag_dag(dag, root_dag=dag)
scheduler.dagbag.sync_to_db()
session = settings.Session()
@ -2112,8 +2122,9 @@ class TestSchedulerJob(unittest.TestCase):
scheduler._send_sla_callbacks_to_processor = mock.Mock()
# Sync DAG into DB
scheduler.dagbag.bag_dag(dag, root_dag=dag)
scheduler.dagbag.sync_to_db()
with mock.patch.object(settings, "STORE_DAG_CODE", False):
scheduler.dagbag.bag_dag(dag, root_dag=dag)
scheduler.dagbag.sync_to_db()
session = settings.Session()
orm_dag = session.query(DagModel).get(dag.dag_id)
@ -3748,7 +3759,7 @@ class TestSchedulerJobQueriesCount(unittest.TestCase):
('core', 'store_serialized_dags'): 'True',
}), mock.patch.object(settings, 'STORE_SERIALIZED_DAGS', True):
dagruns = []
dagbag = DagBag(dag_folder=ELASTIC_DAG_FILE, include_examples=False)
dagbag = DagBag(dag_folder=ELASTIC_DAG_FILE, include_examples=False, read_dags_from_db=False)
dagbag.sync_to_db()
dag_ids = dagbag.dag_ids

@ -694,7 +694,7 @@ class TestDag(unittest.TestCase):
DAG(f'dag-bulk-sync-{i}', start_date=DEFAULT_DATE, tags=["test-dag"]) for i in range(0, 4)
]
with assert_queries_count(5):
with assert_queries_count(7):
DAG.bulk_write_to_db(dags)
with create_session() as session:
self.assertEqual(
@ -711,14 +711,14 @@ class TestDag(unittest.TestCase):
set(session.query(DagTag.dag_id, DagTag.name).all())
)
# Re-sync should do fewer queries
with assert_queries_count(3):
with assert_queries_count(4):
DAG.bulk_write_to_db(dags)
with assert_queries_count(3):
with assert_queries_count(4):
DAG.bulk_write_to_db(dags)
# Adding tags
for dag in dags:
dag.tags.append("test-dag2")
with assert_queries_count(4):
with assert_queries_count(5):
DAG.bulk_write_to_db(dags)
with create_session() as session:
self.assertEqual(
@ -741,7 +741,7 @@ class TestDag(unittest.TestCase):
# Removing tags
for dag in dags:
dag.tags.remove("test-dag")
with assert_queries_count(4):
with assert_queries_count(5):
DAG.bulk_write_to_db(dags)
with create_session() as session:
self.assertEqual(
@ -972,7 +972,8 @@ class TestDag(unittest.TestCase):
)
dag.fileloc = dag_fileloc
session = settings.Session()
dag.sync_to_db(session=session)
with mock.patch.object(settings, "STORE_DAG_CODE", False):
dag.sync_to_db(session=session)
orm_dag = session.query(DagModel).filter(DagModel.dag_id == dag_id).one()

@ -642,7 +642,6 @@ class TestDagBag(unittest.TestCase):
with create_session() as session:
session.query(DagModel).filter(DagModel.dag_id == 'test_deactivate_unknown_dags').delete()
@patch("airflow.models.dagbag.settings.STORE_SERIALIZED_DAGS", True)
def test_serialized_dags_are_written_to_db_on_sync(self):
"""
Test that when dagbag.sync_to_db is called the DAGs are Serialized and written to DB
@ -662,7 +661,6 @@ class TestDagBag(unittest.TestCase):
new_serialized_dags_count = session.query(func.count(SerializedDagModel.dag_id)).scalar()
self.assertEqual(new_serialized_dags_count, 1)
@patch("airflow.models.dagbag.settings.STORE_SERIALIZED_DAGS", True)
@patch("airflow.models.dagbag.settings.MIN_SERIALIZED_DAG_UPDATE_INTERVAL", 5)
@patch("airflow.models.dagbag.settings.MIN_SERIALIZED_DAG_FETCH_INTERVAL", 5)
def test_get_dag_with_dag_serialization(self):

@ -478,64 +478,61 @@ class TestTaskInstance(unittest.TestCase):
run_with_error(ti)
self.assertEqual(ti.state, State.FAILED)
@parameterized.expand([
(False, None,),
(True, {'env': None, 'bash_command': 'echo test_retry_handling; exit 1'},),
])
def test_retry_handling(self, dag_serialization, expected_rendered_ti_fields):
def test_retry_handling(self):
"""
Test that task retries are handled properly
"""
with patch("airflow.models.taskinstance.STORE_SERIALIZED_DAGS", dag_serialization):
dag = models.DAG(dag_id='test_retry_handling')
task = BashOperator(
task_id='test_retry_handling_op',
bash_command='echo {{dag.dag_id}}; exit 1',
retries=1,
retry_delay=datetime.timedelta(seconds=0),
dag=dag,
owner='test_pool',
start_date=timezone.datetime(2016, 2, 1, 0, 0, 0))
expected_rendered_ti_fields = {'env': None, 'bash_command': 'echo test_retry_handling; exit 1'}
def run_with_error(ti):
try:
ti.run()
except AirflowException:
pass
dag = models.DAG(dag_id='test_retry_handling')
task = BashOperator(
task_id='test_retry_handling_op',
bash_command='echo {{dag.dag_id}}; exit 1',
retries=1,
retry_delay=datetime.timedelta(seconds=0),
dag=dag,
owner='test_pool',
start_date=timezone.datetime(2016, 2, 1, 0, 0, 0))
ti = TI(
task=task, execution_date=timezone.utcnow())
self.assertEqual(ti.try_number, 1)
def run_with_error(ti):
try:
ti.run()
except AirflowException:
pass
# first run -- up for retry
run_with_error(ti)
self.assertEqual(ti.state, State.UP_FOR_RETRY)
self.assertEqual(ti._try_number, 1)
self.assertEqual(ti.try_number, 2)
ti = TI(
task=task, execution_date=timezone.utcnow())
self.assertEqual(ti.try_number, 1)
# second run -- fail
run_with_error(ti)
self.assertEqual(ti.state, State.FAILED)
self.assertEqual(ti._try_number, 2)
self.assertEqual(ti.try_number, 3)
# first run -- up for retry
run_with_error(ti)
self.assertEqual(ti.state, State.UP_FOR_RETRY)
self.assertEqual(ti._try_number, 1)
self.assertEqual(ti.try_number, 2)
# Clear the TI state since you can't run a task with a FAILED state without
# clearing it first
dag.clear()
# second run -- fail
run_with_error(ti)
self.assertEqual(ti.state, State.FAILED)
self.assertEqual(ti._try_number, 2)
self.assertEqual(ti.try_number, 3)
# third run -- up for retry
run_with_error(ti)
self.assertEqual(ti.state, State.UP_FOR_RETRY)
self.assertEqual(ti._try_number, 3)
self.assertEqual(ti.try_number, 4)
# Clear the TI state since you can't run a task with a FAILED state without
# clearing it first
dag.clear()
# fourth run -- fail
run_with_error(ti)
ti.refresh_from_db()
self.assertEqual(ti.state, State.FAILED)
self.assertEqual(ti._try_number, 4)
self.assertEqual(ti.try_number, 5)
self.assertEqual(RenderedTaskInstanceFields.get_templated_fields(ti), expected_rendered_ti_fields)
# third run -- up for retry
run_with_error(ti)
self.assertEqual(ti.state, State.UP_FOR_RETRY)
self.assertEqual(ti._try_number, 3)
self.assertEqual(ti.try_number, 4)
# fourth run -- fail
run_with_error(ti)
ti.refresh_from_db()
self.assertEqual(ti.state, State.FAILED)
self.assertEqual(ti._try_number, 4)
self.assertEqual(ti.try_number, 5)
self.assertEqual(RenderedTaskInstanceFields.get_templated_fields(ti), expected_rendered_ti_fields)
def test_next_retry_datetime(self):
delay = datetime.timedelta(seconds=30)
@ -1650,13 +1647,7 @@ class TestTaskInstance(unittest.TestCase):
execution_date=DEFAULT_DATE, mark_success=True)
assert assert_command == generate_command
@parameterized.expand([
(True, ),
(False, )
])
def test_get_rendered_template_fields(self, store_serialized_dag):
# SetUp
settings.STORE_SERIALIZED_DAGS = store_serialized_dag
def test_get_rendered_template_fields(self):
with DAG('test-dag', start_date=DEFAULT_DATE):
task = BashOperator(task_id='op1', bash_command="{{ task.task_id }}")
@ -1673,7 +1664,6 @@ class TestTaskInstance(unittest.TestCase):
new_ti = TI(task=new_task, execution_date=DEFAULT_DATE)
new_ti.get_rendered_template_fields()
self.assertEqual(settings.STORE_SERIALIZED_DAGS, store_serialized_dag)
self.assertEqual("op1", ti.task.bash_command)
# CleanUp
@ -1726,7 +1716,7 @@ class TestRunRawTaskQueriesCount(unittest.TestCase):
@parameterized.expand([
# Expected queries, mark_success
(7, False),
(10, False),
(5, True),
])
def test_execute_queries_count(self, expected_query_count, mark_success):
@ -1748,9 +1738,7 @@ class TestRunRawTaskQueriesCount(unittest.TestCase):
ti.state = State.RUNNING
session.merge(ti)
with assert_queries_count(10), patch(
"airflow.models.taskinstance.STORE_SERIALIZED_DAGS", True
):
with assert_queries_count(10):
ti._run_raw_task()
def test_operator_field_with_serialization(self):

@ -22,7 +22,8 @@ from datetime import datetime
from unittest import TestCase
from airflow.exceptions import DagRunAlreadyExists
from airflow.models import DAG, DagModel, DagRun, Log, TaskInstance
from airflow.models import DAG, DagBag, DagModel, DagRun, Log, TaskInstance
from airflow.models.serialized_dag import SerializedDagModel
from airflow.operators.dagrun_operator import TriggerDagRunOperator
from airflow.utils import timezone
from airflow.utils.session import create_session
@ -57,13 +58,16 @@ class TestDagRunOperator(TestCase):
session.commit()
self.dag = DAG(TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE})
dagbag = DagBag(f.name, read_dags_from_db=False, include_examples=False)
dagbag.bag_dag(self.dag, root_dag=self.dag)
dagbag.sync_to_db()
def tearDown(self):
"""Cleanup state after testing in DB."""
with create_session() as session:
session.query(Log).filter(Log.dag_id == TEST_DAG_ID).delete(synchronize_session=False)
for dbmodel in [DagModel, DagRun, TaskInstance]:
session.query(dbmodel).filter(dbmodel.dag_id == TRIGGERED_DAG_ID).delete(
for dbmodel in [DagModel, DagRun, TaskInstance, SerializedDagModel]:
session.query(dbmodel).filter(dbmodel.dag_id.in_([TRIGGERED_DAG_ID, TEST_DAG_ID])).delete(
synchronize_session=False
)

@ -192,7 +192,7 @@ class TestDagFileProcessorManager(unittest.TestCase):
pickle_dags=False,
async_mode=True)
dagbag = DagBag(TEST_DAG_FOLDER)
dagbag = DagBag(TEST_DAG_FOLDER, read_dags_from_db=False)
with create_session() as session:
session.query(LJ).delete()
dag = dagbag.get_dag('example_branch_operator')
@ -234,7 +234,7 @@ class TestDagFileProcessorManager(unittest.TestCase):
test_dag_path = os.path.join(TEST_DAG_FOLDER, 'test_example_bash_operator.py')
with conf_vars({('scheduler', 'max_threads'): '1',
('core', 'load_examples'): 'False'}):
dagbag = DagBag(test_dag_path)
dagbag = DagBag(test_dag_path, read_dags_from_db=False)
with create_session() as session:
session.query(LJ).delete()
dag = dagbag.get_dag('test_example_bash_operator')

@ -18,22 +18,14 @@
import json
import unittest
from parameterized import parameterized_class
from airflow.api.common.experimental.trigger_dag import trigger_dag
from airflow.models import DagBag, DagRun
from airflow.models.serialized_dag import SerializedDagModel
from airflow.settings import Session
from airflow.www import app as application
from tests.test_utils.config import conf_vars
@parameterized_class([
{"dag_serialization": "False"},
{"dag_serialization": "True"},
])
class TestDagRunsEndpoint(unittest.TestCase):
dag_serialization = "False"
@classmethod
def setUpClass(cls):
@ -60,101 +52,83 @@ class TestDagRunsEndpoint(unittest.TestCase):
super().tearDown()
def test_get_dag_runs_success(self):
with conf_vars(
{("core", "store_serialized_dags"): self.dag_serialization}
):
url_template = '/api/experimental/dags/{}/dag_runs'
dag_id = 'example_bash_operator'
# Create DagRun
dag_run = trigger_dag(
dag_id=dag_id, run_id='test_get_dag_runs_success')
url_template = '/api/experimental/dags/{}/dag_runs'
dag_id = 'example_bash_operator'
# Create DagRun
dag_run = trigger_dag(
dag_id=dag_id, run_id='test_get_dag_runs_success')
response = self.app.get(url_template.format(dag_id))
self.assertEqual(200, response.status_code)
data = json.loads(response.data.decode('utf-8'))
response = self.app.get(url_template.format(dag_id))
self.assertEqual(200, response.status_code)
data = json.loads(response.data.decode('utf-8'))
self.assertIsInstance(data, list)
self.assertEqual(len(data), 1)
self.assertEqual(data[0]['dag_id'], dag_id)
self.assertEqual(data[0]['id'], dag_run.id)
self.assertIsInstance(data, list)
self.assertEqual(len(data), 1)
self.assertEqual(data[0]['dag_id'], dag_id)
self.assertEqual(data[0]['id'], dag_run.id)
def test_get_dag_runs_success_with_state_parameter(self):
with conf_vars(
{("core", "store_serialized_dags"): self.dag_serialization}
):
url_template = '/api/experimental/dags/{}/dag_runs?state=running'
dag_id = 'example_bash_operator'
# Create DagRun
dag_run = trigger_dag(
dag_id=dag_id, run_id='test_get_dag_runs_success')
url_template = '/api/experimental/dags/{}/dag_runs?state=running'
dag_id = 'example_bash_operator'
# Create DagRun
dag_run = trigger_dag(
dag_id=dag_id, run_id='test_get_dag_runs_success')
response = self.app.get(url_template.format(dag_id))
self.assertEqual(200, response.status_code)
data = json.loads(response.data.decode('utf-8'))
response = self.app.get(url_template.format(dag_id))
self.assertEqual(200, response.status_code)
data = json.loads(response.data.decode('utf-8'))
self.assertIsInstance(data, list)
self.assertEqual(len(data), 1)
self.assertEqual(data[0]['dag_id'], dag_id)
self.assertEqual(data[0]['id'], dag_run.id)
self.assertIsInstance(data, list)
self.assertEqual(len(data), 1)
self.assertEqual(data[0]['dag_id'], dag_id)
self.assertEqual(data[0]['id'], dag_run.id)
def test_get_dag_runs_success_with_capital_state_parameter(self):
with conf_vars(
{("core", "store_serialized_dags"): self.dag_serialization}
):
url_template = '/api/experimental/dags/{}/dag_runs?state=RUNNING'
dag_id = 'example_bash_operator'
# Create DagRun
dag_run = trigger_dag(
dag_id=dag_id, run_id='test_get_dag_runs_success')
url_template = '/api/experimental/dags/{}/dag_runs?state=RUNNING'
dag_id = 'example_bash_operator'
# Create DagRun
dag_run = trigger_dag(
dag_id=dag_id, run_id='test_get_dag_runs_success')
response = self.app.get(url_template.format(dag_id))
self.assertEqual(200, response.status_code)
data = json.loads(response.data.decode('utf-8'))
response = self.app.get(url_template.format(dag_id))
self.assertEqual(200, response.status_code)
data = json.loads(response.data.decode('utf-8'))
self.assertIsInstance(data, list)
self.assertEqual(len(data), 1)
self.assertEqual(data[0]['dag_id'], dag_id)
self.assertEqual(data[0]['id'], dag_run.id)
self.assertIsInstance(data, list)
self.assertEqual(len(data), 1)
self.assertEqual(data[0]['dag_id'], dag_id)
self.assertEqual(data[0]['id'], dag_run.id)
def test_get_dag_runs_success_with_state_no_result(self):
with conf_vars(
{("core", "store_serialized_dags"): self.dag_serialization}
):
url_template = '/api/experimental/dags/{}/dag_runs?state=dummy'
dag_id = 'example_bash_operator'
# Create DagRun
trigger_dag(dag_id=dag_id, run_id='test_get_dag_runs_success')
url_template = '/api/experimental/dags/{}/dag_runs?state=dummy'
dag_id = 'example_bash_operator'
# Create DagRun
trigger_dag(dag_id=dag_id, run_id='test_get_dag_runs_success')
response = self.app.get(url_template.format(dag_id))
self.assertEqual(200, response.status_code)
data = json.loads(response.data.decode('utf-8'))
response = self.app.get(url_template.format(dag_id))
self.assertEqual(200, response.status_code)
data = json.loads(response.data.decode('utf-8'))
self.assertIsInstance(data, list)
self.assertEqual(len(data), 0)
self.assertIsInstance(data, list)
self.assertEqual(len(data), 0)
def test_get_dag_runs_invalid_dag_id(self):
with conf_vars(
{("core", "store_serialized_dags"): self.dag_serialization}
):
url_template = '/api/experimental/dags/{}/dag_runs'
dag_id = 'DUMMY_DAG'
url_template = '/api/experimental/dags/{}/dag_runs'
dag_id = 'DUMMY_DAG'
response = self.app.get(url_template.format(dag_id))
self.assertEqual(400, response.status_code)
data = json.loads(response.data.decode('utf-8'))
response = self.app.get(url_template.format(dag_id))
self.assertEqual(400, response.status_code)
data = json.loads(response.data.decode('utf-8'))
self.assertNotIsInstance(data, list)
self.assertNotIsInstance(data, list)
def test_get_dag_runs_no_runs(self):
with conf_vars(
{("core", "store_serialized_dags"): self.dag_serialization}
):
url_template = '/api/experimental/dags/{}/dag_runs'
dag_id = 'example_bash_operator'
url_template = '/api/experimental/dags/{}/dag_runs'
dag_id = 'example_bash_operator'
response = self.app.get(url_template.format(dag_id))
self.assertEqual(200, response.status_code)
data = json.loads(response.data.decode('utf-8'))
response = self.app.get(url_template.format(dag_id))
self.assertEqual(200, response.status_code)
data = json.loads(response.data.decode('utf-8'))
self.assertIsInstance(data, list)
self.assertEqual(len(data), 0)
self.assertIsInstance(data, list)
self.assertEqual(len(data), 0)

@ -22,8 +22,6 @@ from datetime import timedelta
from unittest import mock
from urllib.parse import quote_plus
from parameterized import parameterized_class
from airflow import settings
from airflow.api.common.experimental.trigger_dag import trigger_dag
from airflow.models import DagBag, DagRun, Pool, TaskInstance
@ -32,7 +30,6 @@ from airflow.settings import Session
from airflow.utils.timezone import datetime, parse as parse_datetime, utcnow
from airflow.version import version
from airflow.www import app as application
from tests.test_utils.config import conf_vars
from tests.test_utils.db import clear_db_pools
ROOT_FOLDER = os.path.realpath(
@ -61,12 +58,7 @@ class TestBase(unittest.TestCase):
)
@parameterized_class([
{"dag_serialization": "False"},
{"dag_serialization": "True"},
])
class TestApiExperimental(TestBase):
dag_serialization = "False"
@classmethod
def setUpClass(cls):
@ -100,287 +92,261 @@ class TestApiExperimental(TestBase):
self.assert_deprecated(resp_raw)
def test_task_info(self):
with conf_vars(
{("core", "store_serialized_dags"): self.dag_serialization}
):
url_template = '/api/experimental/dags/{}/tasks/{}'
url_template = '/api/experimental/dags/{}/tasks/{}'
response = self.client.get(
url_template.format('example_bash_operator', 'runme_0')
)
self.assert_deprecated(response)
response = self.client.get(
url_template.format('example_bash_operator', 'runme_0')
)
self.assert_deprecated(response)
self.assertIn('"email"', response.data.decode('utf-8'))
self.assertNotIn('error', response.data.decode('utf-8'))
self.assertEqual(200, response.status_code)
self.assertIn('"email"', response.data.decode('utf-8'))
self.assertNotIn('error', response.data.decode('utf-8'))
self.assertEqual(200, response.status_code)
response = self.client.get(
url_template.format('example_bash_operator', 'DNE')
)
self.assertIn('error', response.data.decode('utf-8'))
self.assertEqual(404, response.status_code)
response = self.client.get(
url_template.format('example_bash_operator', 'DNE')
)
self.assertIn('error', response.data.decode('utf-8'))
self.assertEqual(404, response.status_code)
response = self.client.get(
url_template.format('DNE', 'DNE')
)
self.assertIn('error', response.data.decode('utf-8'))
self.assertEqual(404, response.status_code)
response = self.client.get(
url_template.format('DNE', 'DNE')
)
self.assertIn('error', response.data.decode('utf-8'))
self.assertEqual(404, response.status_code)
def test_get_dag_code(self):
with conf_vars(
{("core", "store_serialized_dags"): self.dag_serialization}
):
url_template = '/api/experimental/dags/{}/code'
url_template = '/api/experimental/dags/{}/code'
response = self.client.get(
url_template.format('example_bash_operator')
)
self.assert_deprecated(response)
self.assertIn('BashOperator(', response.data.decode('utf-8'))
self.assertEqual(200, response.status_code)
response = self.client.get(
url_template.format('example_bash_operator')
)
self.assert_deprecated(response)
self.assertIn('BashOperator(', response.data.decode('utf-8'))
self.assertEqual(200, response.status_code)
response = self.client.get(
url_template.format('xyz')
)
self.assertEqual(404, response.status_code)
response = self.client.get(
url_template.format('xyz')
)
self.assertEqual(404, response.status_code)
def test_dag_paused(self):
with conf_vars(
{("core", "store_serialized_dags"): self.dag_serialization}
):
pause_url_template = '/api/experimental/dags/{}/paused/{}'
paused_url_template = '/api/experimental/dags/{}/paused'
paused_url = paused_url_template.format('example_bash_operator')
pause_url_template = '/api/experimental/dags/{}/paused/{}'
paused_url_template = '/api/experimental/dags/{}/paused'
paused_url = paused_url_template.format('example_bash_operator')
response = self.client.get(
pause_url_template.format('example_bash_operator', 'true')
)
self.assert_deprecated(response)
self.assertIn('ok', response.data.decode('utf-8'))
self.assertEqual(200, response.status_code)
response = self.client.get(
pause_url_template.format('example_bash_operator', 'true')
)
self.assert_deprecated(response)
self.assertIn('ok', response.data.decode('utf-8'))
self.assertEqual(200, response.status_code)
paused_response = self.client.get(paused_url)
paused_response = self.client.get(paused_url)
self.assertEqual(200, paused_response.status_code)
self.assertEqual({"is_paused": True}, paused_response.json)
self.assertEqual(200, paused_response.status_code)
self.assertEqual({"is_paused": True}, paused_response.json)
response = self.client.get(
pause_url_template.format('example_bash_operator', 'false')
)
self.assertIn('ok', response.data.decode('utf-8'))
self.assertEqual(200, response.status_code)
response = self.client.get(
pause_url_template.format('example_bash_operator', 'false')
)
self.assertIn('ok', response.data.decode('utf-8'))
self.assertEqual(200, response.status_code)
paused_response = self.client.get(paused_url)
paused_response = self.client.get(paused_url)
self.assertEqual(200, paused_response.status_code)
self.assertEqual({"is_paused": False}, paused_response.json)
self.assertEqual(200, paused_response.status_code)
self.assertEqual({"is_paused": False}, paused_response.json)
def test_trigger_dag(self):
with conf_vars(
{("core", "store_serialized_dags"): self.dag_serialization}
):
url_template = '/api/experimental/dags/{}/dag_runs'
run_id = 'my_run' + utcnow().isoformat()
response = self.client.post(
url_template.format('example_bash_operator'),
data=json.dumps({'run_id': run_id}),
content_type="application/json"
)
self.assert_deprecated(response)
url_template = '/api/experimental/dags/{}/dag_runs'
run_id = 'my_run' + utcnow().isoformat()
response = self.client.post(
url_template.format('example_bash_operator'),
data=json.dumps({'run_id': run_id}),
content_type="application/json"
)
self.assert_deprecated(response)
self.assertEqual(200, response.status_code)
response_execution_date = parse_datetime(
json.loads(response.data.decode('utf-8'))['execution_date'])
self.assertEqual(0, response_execution_date.microsecond)
self.assertEqual(200, response.status_code)
response_execution_date = parse_datetime(
json.loads(response.data.decode('utf-8'))['execution_date'])
self.assertEqual(0, response_execution_date.microsecond)
# Check execution_date is correct
response = json.loads(response.data.decode('utf-8'))
dagbag = DagBag()
dag = dagbag.get_dag('example_bash_operator')
dag_run = dag.get_dagrun(response_execution_date)
dag_run_id = dag_run.run_id
self.assertEqual(run_id, dag_run_id)
self.assertEqual(dag_run_id, response['run_id'])
# Check execution_date is correct
response = json.loads(response.data.decode('utf-8'))
dagbag = DagBag()
dag = dagbag.get_dag('example_bash_operator')
dag_run = dag.get_dagrun(response_execution_date)
dag_run_id = dag_run.run_id
self.assertEqual(run_id, dag_run_id)
self.assertEqual(dag_run_id, response['run_id'])
# Test error for nonexistent dag
response = self.client.post(
url_template.format('does_not_exist_dag'),
data=json.dumps({}),
content_type="application/json"
)
self.assertEqual(404, response.status_code)
# Test error for nonexistent dag
response = self.client.post(
url_template.format('does_not_exist_dag'),
data=json.dumps({}),
content_type="application/json"
)
self.assertEqual(404, response.status_code)
def test_trigger_dag_for_date(self):
with conf_vars(
{("core", "store_serialized_dags"): self.dag_serialization}
):
url_template = '/api/experimental/dags/{}/dag_runs'
dag_id = 'example_bash_operator'
execution_date = utcnow() + timedelta(hours=1)
datetime_string = execution_date.isoformat()
url_template = '/api/experimental/dags/{}/dag_runs'
dag_id = 'example_bash_operator'
execution_date = utcnow() + timedelta(hours=1)
datetime_string = execution_date.isoformat()
# Test correct execution with execution date
response = self.client.post(
url_template.format(dag_id),
data=json.dumps({'execution_date': datetime_string}),
content_type="application/json"
)
self.assert_deprecated(response)
self.assertEqual(200, response.status_code)
self.assertEqual(datetime_string, json.loads(response.data.decode('utf-8'))['execution_date'])
# Test correct execution with execution date
response = self.client.post(
url_template.format(dag_id),
data=json.dumps({'execution_date': datetime_string}),
content_type="application/json"
)
self.assert_deprecated(response)
self.assertEqual(200, response.status_code)
self.assertEqual(datetime_string, json.loads(response.data.decode('utf-8'))['execution_date'])
dagbag = DagBag()
dag = dagbag.get_dag(dag_id)
dag_run = dag.get_dagrun(execution_date)
self.assertTrue(dag_run,
'Dag Run not found for execution date {}'
.format(execution_date))
dagbag = DagBag()
dag = dagbag.get_dag(dag_id)
dag_run = dag.get_dagrun(execution_date)
self.assertTrue(dag_run,
'Dag Run not found for execution date {}'
.format(execution_date))
# Test correct execution with execution date and microseconds replaced
response = self.client.post(
url_template.format(dag_id),
data=json.dumps({'execution_date': datetime_string, 'replace_microseconds': 'true'}),
content_type="application/json"
)
self.assertEqual(200, response.status_code)
response_execution_date = parse_datetime(
json.loads(response.data.decode('utf-8'))['execution_date'])
self.assertEqual(0, response_execution_date.microsecond)
# Test correct execution with execution date and microseconds replaced
response = self.client.post(
url_template.format(dag_id),
data=json.dumps({'execution_date': datetime_string, 'replace_microseconds': 'true'}),
content_type="application/json"
)
self.assertEqual(200, response.status_code)
response_execution_date = parse_datetime(
json.loads(response.data.decode('utf-8'))['execution_date'])
self.assertEqual(0, response_execution_date.microsecond)
dagbag = DagBag()
dag = dagbag.get_dag(dag_id)
dag_run = dag.get_dagrun(response_execution_date)
self.assertTrue(dag_run,
'Dag Run not found for execution date {}'
.format(execution_date))
dagbag = DagBag()
dag = dagbag.get_dag(dag_id)
dag_run = dag.get_dagrun(response_execution_date)
self.assertTrue(dag_run,
'Dag Run not found for execution date {}'
.format(execution_date))
# Test error for nonexistent dag
response = self.client.post(
url_template.format('does_not_exist_dag'),
data=json.dumps({'execution_date': datetime_string}),
content_type="application/json"
)
self.assertEqual(404, response.status_code)
# Test error for nonexistent dag
response = self.client.post(
url_template.format('does_not_exist_dag'),
data=json.dumps({'execution_date': datetime_string}),
content_type="application/json"
)
self.assertEqual(404, response.status_code)
# Test error for bad datetime format
response = self.client.post(
url_template.format(dag_id),
data=json.dumps({'execution_date': 'not_a_datetime'}),
content_type="application/json"
)
self.assertEqual(400, response.status_code)
# Test error for bad datetime format
response = self.client.post(
url_template.format(dag_id),
data=json.dumps({'execution_date': 'not_a_datetime'}),
content_type="application/json"
)
self.assertEqual(400, response.status_code)
def test_task_instance_info(self):
with conf_vars(
{("core", "store_serialized_dags"): self.dag_serialization}
):
url_template = '/api/experimental/dags/{}/dag_runs/{}/tasks/{}'
dag_id = 'example_bash_operator'
task_id = 'also_run_this'
execution_date = utcnow().replace(microsecond=0)
datetime_string = quote_plus(execution_date.isoformat())
wrong_datetime_string = quote_plus(
datetime(1990, 1, 1, 1, 1, 1).isoformat()
)
url_template = '/api/experimental/dags/{}/dag_runs/{}/tasks/{}'
dag_id = 'example_bash_operator'
task_id = 'also_run_this'
execution_date = utcnow().replace(microsecond=0)
datetime_string = quote_plus(execution_date.isoformat())
wrong_datetime_string = quote_plus(
datetime(1990, 1, 1, 1, 1, 1).isoformat()
)
# Create DagRun
trigger_dag(dag_id=dag_id,
run_id='test_task_instance_info_run',
execution_date=execution_date)
# Create DagRun
trigger_dag(dag_id=dag_id,
run_id='test_task_instance_info_run',
execution_date=execution_date)
# Test Correct execution
response = self.client.get(
url_template.format(dag_id, datetime_string, task_id)
)
self.assert_deprecated(response)
self.assertEqual(200, response.status_code)
self.assertIn('state', response.data.decode('utf-8'))
self.assertNotIn('error', response.data.decode('utf-8'))
# Test Correct execution
response = self.client.get(
url_template.format(dag_id, datetime_string, task_id)
)
self.assert_deprecated(response)
self.assertEqual(200, response.status_code)
self.assertIn('state', response.data.decode('utf-8'))
self.assertNotIn('error', response.data.decode('utf-8'))
# Test error for nonexistent dag
response = self.client.get(
url_template.format('does_not_exist_dag', datetime_string,
task_id),
)
self.assertEqual(404, response.status_code)
self.assertIn('error', response.data.decode('utf-8'))
# Test error for nonexistent dag
response = self.client.get(
url_template.format('does_not_exist_dag', datetime_string,
task_id),
)
self.assertEqual(404, response.status_code)
self.assertIn('error', response.data.decode('utf-8'))
# Test error for nonexistent task
response = self.client.get(
url_template.format(dag_id, datetime_string, 'does_not_exist_task')
)
self.assertEqual(404, response.status_code)
self.assertIn('error', response.data.decode('utf-8'))
# Test error for nonexistent task
response = self.client.get(
url_template.format(dag_id, datetime_string, 'does_not_exist_task')
)
self.assertEqual(404, response.status_code)
self.assertIn('error', response.data.decode('utf-8'))
# Test error for nonexistent dag run (wrong execution_date)
response = self.client.get(
url_template.format(dag_id, wrong_datetime_string, task_id)
)
self.assertEqual(404, response.status_code)
self.assertIn('error', response.data.decode('utf-8'))
# Test error for nonexistent dag run (wrong execution_date)
response = self.client.get(
url_template.format(dag_id, wrong_datetime_string, task_id)
)
self.assertEqual(404, response.status_code)
self.assertIn('error', response.data.decode('utf-8'))
# Test error for bad datetime format
response = self.client.get(
url_template.format(dag_id, 'not_a_datetime', task_id)
)
self.assertEqual(400, response.status_code)
self.assertIn('error', response.data.decode('utf-8'))
# Test error for bad datetime format
response = self.client.get(
url_template.format(dag_id, 'not_a_datetime', task_id)
)
self.assertEqual(400, response.status_code)
self.assertIn('error', response.data.decode('utf-8'))
def test_dagrun_status(self):
with conf_vars(
{("core", "store_serialized_dags"): self.dag_serialization}
):
url_template = '/api/experimental/dags/{}/dag_runs/{}'
dag_id = 'example_bash_operator'
execution_date = utcnow().replace(microsecond=0)
datetime_string = quote_plus(execution_date.isoformat())
wrong_datetime_string = quote_plus(
datetime(1990, 1, 1, 1, 1, 1).isoformat()
)
url_template = '/api/experimental/dags/{}/dag_runs/{}'
dag_id = 'example_bash_operator'
execution_date = utcnow().replace(microsecond=0)
datetime_string = quote_plus(execution_date.isoformat())
wrong_datetime_string = quote_plus(
datetime(1990, 1, 1, 1, 1, 1).isoformat()
)
# Create DagRun
trigger_dag(dag_id=dag_id,
run_id='test_task_instance_info_run',
execution_date=execution_date)
# Create DagRun
trigger_dag(dag_id=dag_id,
run_id='test_task_instance_info_run',
execution_date=execution_date)
# Test Correct execution
response = self.client.get(
url_template.format(dag_id, datetime_string)
)
self.assert_deprecated(response)
self.assertEqual(200, response.status_code)
self.assertIn('state', response.data.decode('utf-8'))
self.assertNotIn('error', response.data.decode('utf-8'))
# Test Correct execution
response = self.client.get(
url_template.format(dag_id, datetime_string)
)
self.assert_deprecated(response)
self.assertEqual(200, response.status_code)
self.assertIn('state', response.data.decode('utf-8'))
self.assertNotIn('error', response.data.decode('utf-8'))
# Test error for nonexistent dag
response = self.client.get(
url_template.format('does_not_exist_dag', datetime_string),
)
self.assertEqual(404, response.status_code)
self.assertIn('error', response.data.decode('utf-8'))
# Test error for nonexistent dag
response = self.client.get(
url_template.format('does_not_exist_dag', datetime_string),
)
self.assertEqual(404, response.status_code)
self.assertIn('error', response.data.decode('utf-8'))
# Test error for nonexistent dag run (wrong execution_date)
response = self.client.get(
url_template.format(dag_id, wrong_datetime_string)
)
self.assertEqual(404, response.status_code)
self.assertIn('error', response.data.decode('utf-8'))
# Test error for bad datetime format
response = self.client.get(
url_template.format(dag_id, 'not_a_datetime')
)
self.assertEqual(400, response.status_code)
self.assertIn('error', response.data.decode('utf-8'))
@parameterized_class([
{"dag_serialization": "False"},
{"dag_serialization": "True"},
])
class TestLineageApiExperimental(TestBase):
PAPERMILL_EXAMPLE_DAGS = os.path.join(ROOT_FOLDER, "airflow", "providers", "papermill", "example_dags")
dag_serialization = "False"
@classmethod
def setUpClass(cls):
@ -398,51 +364,48 @@ class TestLineageApiExperimental(TestBase):
@mock.patch("airflow.settings.DAGS_FOLDER", PAPERMILL_EXAMPLE_DAGS)
def test_lineage_info(self):
with conf_vars(
{("core", "store_serialized_dags"): self.dag_serialization}
):
url_template = '/api/experimental/lineage/{}/{}'
dag_id = 'example_papermill_operator'
execution_date = utcnow().replace(microsecond=0)
datetime_string = quote_plus(execution_date.isoformat())
wrong_datetime_string = quote_plus(
datetime(1990, 1, 1, 1, 1, 1).isoformat()
)
# create DagRun
trigger_dag(dag_id=dag_id,
run_id='test_lineage_info_run',
execution_date=execution_date)
# test correct execution
response = self.client.get(
url_template.format(dag_id, datetime_string)
)
self.assert_deprecated(response)
self.assertEqual(200, response.status_code)
self.assertIn('task_ids', response.data.decode('utf-8'))
self.assertNotIn('error', response.data.decode('utf-8'))
# Test error for nonexistent dag
response = self.client.get(
url_template.format('does_not_exist_dag', datetime_string),
)
self.assertEqual(404, response.status_code)
self.assertIn('error', response.data.decode('utf-8'))
# Test error for nonexistent dag run (wrong execution_date)
response = self.client.get(
url_template.format(dag_id, wrong_datetime_string)
)
self.assertEqual(404, response.status_code)
self.assertIn('error', response.data.decode('utf-8'))
# Test error for bad datetime format
response = self.client.get(
url_template.format(dag_id, 'not_a_datetime')
)
self.assertEqual(400, response.status_code)
self.assertIn('error', response.data.decode('utf-8'))
class TestPoolApiExperimental(TestBase):


@ -462,9 +462,9 @@ class TestAirflowBaseViews(TestBase):
@classmethod
def setUpClass(cls):
super().setUpClass()
cls.dagbag = models.DagBag(include_examples=True)
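        # The example DAGs are now written to the DB first and then loaded back
        # with read_dags_from_db=True, so the views work from serialized DAGs.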
models.DagBag(include_examples=True).sync_to_db()
cls.dagbag = models.DagBag(include_examples=True, read_dags_from_db=True)
cls.app.dag_bag = cls.dagbag
DAG.bulk_write_to_db(cls.dagbag.dags.values())
def setUp(self):
super().setUp()
@ -474,9 +474,9 @@ class TestAirflowBaseViews(TestBase):
self.prepare_dagruns()
def prepare_dagruns(self):
self.bash_dag = self.dagbag.dags['example_bash_operator']
self.sub_dag = self.dagbag.dags['example_subdag_operator']
self.xcom_dag = self.dagbag.dags['example_xcom']
self.bash_dag = self.dagbag.get_dag('example_bash_operator')
self.sub_dag = self.dagbag.get_dag('example_subdag_operator')
self.xcom_dag = self.dagbag.get_dag('example_xcom')
self.bash_dagrun = self.bash_dag.create_dagrun(
run_type=DagRunType.SCHEDULED,
@ -699,7 +699,7 @@ class TestAirflowBaseViews(TestBase):
("\"", r'\"conf\":{\"abc\":\"\\\"\"}'),
])
def test_escape_in_tree_view(self, test_str, expected_text):
dag = self.dagbag.dags['test_tree_view']
dag = self.dagbag.get_dag('test_tree_view')
dag.create_dagrun(
execution_date=self.EXAMPLE_DAG_DEFAULT_DATE,
start_date=timezone.utcnow(),
@ -713,7 +713,7 @@ class TestAirflowBaseViews(TestBase):
self.check_content_in_response(expected_text, resp)
def test_dag_details_trigger_origin_tree_view(self):
dag = self.dagbag.dags['test_tree_view']
dag = self.dagbag.get_dag('test_tree_view')
dag.create_dagrun(
run_type=DagRunType.SCHEDULED,
execution_date=self.EXAMPLE_DAG_DEFAULT_DATE,
@ -727,7 +727,7 @@ class TestAirflowBaseViews(TestBase):
self.check_content_in_response(href, resp)
def test_dag_details_trigger_origin_graph_view(self):
dag = self.dagbag.dags['test_graph_view']
dag = self.dagbag.get_dag('test_graph_view')
dag.create_dagrun(
run_type=DagRunType.SCHEDULED,
execution_date=self.EXAMPLE_DAG_DEFAULT_DATE,
@ -818,7 +818,8 @@ class TestAirflowBaseViews(TestBase):
url = 'code?dag_id=example_bash_operator'
mock_open_patch = mock.mock_open(read_data='')
mock_open_patch.side_effect = FileNotFoundError
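        # With STORE_DAG_CODE patched off, the code view falls back to reading
        # the DAG file from disk, and the mocked open() makes that read fail.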
with mock.patch('builtins.open', mock_open_patch):
with mock.patch('builtins.open', mock_open_patch), \
mock.patch("airflow.models.dagcode.STORE_DAG_CODE", False):
resp = self.client.get(url, follow_redirects=True)
self.check_content_in_response('Failed to load file', resp)
self.check_content_in_response('example_bash_operator', resp)
@ -1030,23 +1031,11 @@ class TestAirflowBaseViews(TestBase):
resp = self.client.post('refresh?dag_id=example_bash_operator')
self.check_content_in_response('', resp, resp_code=302)
@parameterized.expand([(True,), (False,)])
def test_refresh_all(self, dag_serialization):
with mock.patch('airflow.www.views.settings.STORE_SERIALIZED_DAGS', dag_serialization):
if dag_serialization:
with mock.patch.object(
self.app.dag_bag, 'collect_dags_from_db'
) as collect_dags_from_db:
resp = self.client.post("/refresh_all", follow_redirects=True)
self.check_content_in_response('', resp)
collect_dags_from_db.assert_called_once_with()
else:
with mock.patch.object(
self.app.dag_bag, 'collect_dags'
) as collect_dags:
resp = self.client.post("/refresh_all", follow_redirects=True)
self.check_content_in_response('', resp)
collect_dags.assert_called_once_with(only_if_updated=False)
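    # Serialization is always on now, so refresh_all simply re-collects the
    # serialized DAGs from the DB instead of re-parsing the files.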
def test_refresh_all(self):
with mock.patch.object(self.app.dag_bag, 'collect_dags_from_db') as collect_dags_from_db:
resp = self.client.post("/refresh_all", follow_redirects=True)
self.check_content_in_response('', resp)
collect_dags_from_db.assert_called_once_with()
def test_delete_dag_button_normal(self):
resp = self.client.get('/', follow_redirects=True)
@ -1181,10 +1170,16 @@ class TestLogView(TestBase):
dagbag = self.app.dag_bag
dag = DAG(self.DAG_ID, start_date=self.DEFAULT_DATE)
dag.sync_to_db()
dag_removed = DAG(self.DAG_ID_REMOVED, start_date=self.DEFAULT_DATE)
dag_removed.sync_to_db()
dagbag.bag_dag(dag=dag, root_dag=dag)
dagbag.bag_dag(dag=dag_removed, root_dag=dag_removed)
# Since we don't want to store the code for the DAG defined in this file
with mock.patch.object(settings, "STORE_DAG_CODE", False):
dag.sync_to_db()
dag_removed.sync_to_db()
dagbag.sync_to_db()
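        # Persist the bagged DAGs so the log view can read them from the DB.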
with create_session() as session:
self.ti = TaskInstance(
task=DummyOperator(task_id=self.TASK_ID, dag=dag),
@ -1711,9 +1706,9 @@ class TestDagACLView(TestBase):
cls.appbuilder.sm.del_register_user(user)
def prepare_dagruns(self):
dagbag = models.DagBag(include_examples=True)
self.bash_dag = dagbag.dags['example_bash_operator']
self.sub_dag = dagbag.dags['example_subdag_operator']
dagbag = models.DagBag(include_examples=True, read_dags_from_db=True)
self.bash_dag = dagbag.get_dag("example_bash_operator")
self.sub_dag = dagbag.get_dag("example_subdag_operator")
self.bash_dagrun = self.bash_dag.create_dagrun(
run_type=DagRunType.SCHEDULED,
@ -2414,7 +2409,6 @@ class TestRenderedView(TestBase):
resp = self.client.get(url, follow_redirects=True)
self.check_content_in_response("testdag__task1__20200301", resp)
@mock.patch('airflow.models.taskinstance.STORE_SERIALIZED_DAGS', True)
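    # The STORE_SERIALIZED_DAGS patch is removed; serialization can no longer
    # be switched off, so no patching is needed here.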
def test_user_defined_filter_and_macros_raise_error(self):
"""
Test that the Rendered View is able to show rendered values
@ -2959,11 +2953,11 @@ class TestDecorators(TestBase):
@classmethod
def setUpClass(cls):
super().setUpClass()
dagbag = models.DagBag(include_examples=True)
DAG.bulk_write_to_db(dagbag.dags.values())
cls.bash_dag = dagbag.dags['example_bash_operator']
cls.sub_dag = dagbag.dags['example_subdag_operator']
cls.xcom_dag = dagbag.dags['example_xcom']
models.DagBag(include_examples=True, read_dags_from_db=False).sync_to_db()
dagbag = models.DagBag(include_examples=True, read_dags_from_db=True)
cls.bash_dag = dagbag.get_dag('example_bash_operator')
cls.sub_dag = dagbag.get_dag('example_subdag_operator')
cls.xcom_dag = dagbag.get_dag('example_xcom')
def setUp(self):
super().setUp()