[AIRFLOW-192] Add weight_rule param to BaseOperator

Improved task generation performance significantly
by using sets of
task_ids and dag_ids instead of lists when
calculating total priority
weight.

Closes #2941 from wongwill86/performance-latest
This commit is contained in:
wongwill86 2018-01-18 16:09:40 +01:00 коммит произвёл Bolke de Bruin
Родитель fbba5ef7c3
Коммит dd2bc8cb97
3 изменённых файлов: 210 добавлений и 18 удалений

Просмотреть файл

@ -19,7 +19,6 @@ from __future__ import unicode_literals
from future.standard_library import install_aliases
install_aliases()
from builtins import str
from builtins import object, bytes
import copy
@ -84,8 +83,11 @@ from airflow.utils.operator_resources import Resources
from airflow.utils.state import State
from airflow.utils.timeout import timeout
from airflow.utils.trigger_rule import TriggerRule
from airflow.utils.weight_rule import WeightRule
from airflow.utils.log.logging_mixin import LoggingMixin
install_aliases()
Base = declarative_base()
ID_LEN = 250
XCOM_RETURN_KEY = 'return_value'
@ -2073,6 +2075,29 @@ class BaseOperator(LoggingMixin):
This allows the executor to trigger higher priority tasks before
others when things get backed up.
:type priority_weight: int
:param weight_rule: weighting method used for the effective total
priority weight of the task. Options are:
``{ downstream | upstream | absolute }`` default is ``downstream``
When set to ``downstream`` the effective weight of the task is the
aggregate sum of all downstream descendants. As a result, upstream
tasks will have higher weight and will be scheduled more aggressively
when using positive weight values. This is useful when you have
multiple dag run instances and desire to have all upstream tasks to
complete for all runs before each dag can continue processing
downstream tasks. When set to ``upstream`` the effective weight is the
aggregate sum of all upstream ancestors. This is the opposite where
downstream tasks have higher weight and will be scheduled more
aggressively when using positive weight values. This is useful when you
have multiple dag run instances and prefer to have each dag complete
before starting upstream tasks of other dags. When set to
``absolute``, the effective weight is the exact ``priority_weight``
specified without additional weighting. You may want to do this when
you know exactly what priority weight each task should have.
Additionally, when set to ``absolute``, there is a bonus effect of
significantly speeding up the task creation process for very large
DAGs. Options can be set as string or using the constants defined in
the static class ``airflow.utils.WeightRule``
:type weight_rule: str
:param pool: the slot pool this task should run in, slot pools are a
way to limit concurrency for certain tasks
:type pool: str
@ -2150,6 +2175,7 @@ class BaseOperator(LoggingMixin):
default_args=None,
adhoc=False,
priority_weight=1,
weight_rule=WeightRule.DOWNSTREAM,
queue=configuration.get('celery', 'default_queue'),
pool=None,
sla=None,
@ -2190,7 +2216,7 @@ class BaseOperator(LoggingMixin):
"The trigger_rule must be one of {all_triggers},"
"'{d}.{t}'; received '{tr}'."
.format(all_triggers=TriggerRule.all_triggers,
d=dag.dag_id, t=task_id, tr=trigger_rule))
d=dag.dag_id if dag else "", t=task_id, tr=trigger_rule))
self.trigger_rule = trigger_rule
self.depends_on_past = depends_on_past
@ -2224,6 +2250,14 @@ class BaseOperator(LoggingMixin):
self.params = params or {} # Available in templates!
self.adhoc = adhoc
self.priority_weight = priority_weight
if not WeightRule.is_valid(weight_rule):
raise AirflowException(
"The weight_rule must be one of {all_weight_rules},"
"'{d}.{t}'; received '{tr}'."
.format(all_weight_rules=WeightRule.all_weight_rules,
d=dag.dag_id if dag else "", t=task_id, tr=weight_rule))
self.weight_rule = weight_rule
self.resources = Resources(**(resources or {}))
self.run_as_user = run_as_user
self.task_concurrency = task_concurrency
@ -2402,10 +2436,19 @@ class BaseOperator(LoggingMixin):
@property
def priority_weight_total(self):
    """
    Effective total priority weight of the task, per ``self.weight_rule``.

    ``absolute`` returns ``priority_weight`` unchanged; ``upstream`` adds
    the weights of all flat ancestors; ``downstream`` (the default) adds
    the weights of all flat descendants.

    :return: the aggregated priority weight
    :rtype: int
    """
    if self.weight_rule == WeightRule.ABSOLUTE:
        return self.priority_weight
    # Only UPSTREAM walks ancestors; DOWNSTREAM — and, defensively, any
    # unexpected value — walks descendants, preserving the original
    # fall-through behaviour.
    upstream = self.weight_rule == WeightRule.UPSTREAM
    return self.priority_weight + sum(
        self._dag.task_dict[task_id].priority_weight
        for task_id in self.get_flat_relative_ids(upstream=upstream)
    )
def pre_execute(self, context):
"""
@ -2608,17 +2651,30 @@ class BaseOperator(LoggingMixin):
TI.execution_date <= end_date,
).order_by(TI.execution_date).all()
def get_flat_relative_ids(self, upstream=False, found_descendants=None):
    """
    Get a flat set of relatives' ids, either upstream or downstream.

    :param upstream: walk ancestors when True, descendants when False
    :type upstream: bool
    :param found_descendants: accumulator set shared across the recursive
        calls; callers normally omit it
    :type found_descendants: set or None
    :return: the set of all transitive relative task ids
    :rtype: set
    """
    # `is None`, not truthiness: a caller-supplied empty set must be
    # reused as the accumulator, not silently replaced.
    if found_descendants is None:
        found_descendants = set()
    for relative_id in self.get_direct_relative_ids(upstream):
        if relative_id not in found_descendants:
            found_descendants.add(relative_id)
            relative_task = self._dag.task_dict[relative_id]
            relative_task.get_flat_relative_ids(upstream,
                                                found_descendants)
    return found_descendants
def get_flat_relatives(self, upstream=False):
    """
    Get a flat list of relatives, either upstream or downstream.

    :param upstream: walk ancestors when True, descendants when False
    :type upstream: bool
    :return: the task objects for every transitive relative id
    :rtype: list
    """
    # Resolve each id from get_flat_relative_ids back to its task object.
    return [self._dag.task_dict[task_id]
            for task_id in self.get_flat_relative_ids(upstream)]
def detect_downstream_cycle(self, task=None):
"""
@ -2664,6 +2720,16 @@ class BaseOperator(LoggingMixin):
self.log.info('Rendering template for %s', attr)
self.log.info(content)
def get_direct_relative_ids(self, upstream=False):
    """
    Get the ids of tasks directly connected to this task: upstream
    ancestors when *upstream* is True, otherwise downstream dependents.
    """
    return self._upstream_task_ids if upstream else self._downstream_task_ids
def get_direct_relatives(self, upstream=False):
"""
Get the direct relatives to the current task, upstream or
@ -2704,14 +2770,14 @@ class BaseOperator(LoggingMixin):
# relationships can only be set if the tasks share a single DAG. Tasks
# without a DAG are assigned to that DAG.
dags = set(t.dag for t in [self] + task_list if t.has_dag())
dags = {t._dag.dag_id: t.dag for t in [self] + task_list if t.has_dag()}
if len(dags) > 1:
raise AirflowException(
'Tried to set relationships between tasks in '
'more than one DAG: {}'.format(dags))
'more than one DAG: {}'.format(dags.values()))
elif len(dags) == 1:
dag = list(dags)[0]
dag = dags.popitem()[1]
else:
raise AirflowException(
"Tried to create relationships between tasks that don't have "
@ -4739,7 +4805,7 @@ class DagRun(Base, LoggingMixin):
ti.state = State.REMOVED
# check for missing tasks
for task in dag.tasks:
for task in six.itervalues(dag.task_dict):
if task.adhoc:
continue

Просмотреть файл

@ -0,0 +1,33 @@
# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from __future__ import unicode_literals
from builtins import object
class WeightRule(object):
    """Weighting methods used to compute a task's effective total
    priority weight (see ``BaseOperator.weight_rule``)."""

    DOWNSTREAM = 'downstream'
    UPSTREAM = 'upstream'
    ABSOLUTE = 'absolute'

    @classmethod
    def is_valid(cls, weight_rule):
        """Return True when *weight_rule* names a known rule."""
        return weight_rule in cls.all_weight_rules()

    @classmethod
    def all_weight_rules(cls):
        """Return every rule value, alphabetically ordered."""
        return sorted([cls.ABSOLUTE, cls.DOWNSTREAM, cls.UPSTREAM])

Просмотреть файл

@ -23,6 +23,8 @@ import os
import pendulum
import unittest
import time
import six
import re
from airflow import configuration, models, settings, AirflowException
from airflow.exceptions import AirflowSkipException
@ -39,6 +41,7 @@ from airflow.operators.python_operator import PythonOperator
from airflow.operators.python_operator import ShortCircuitOperator
from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep
from airflow.utils import timezone
from airflow.utils.weight_rule import WeightRule
from airflow.utils.state import State
from airflow.utils.trigger_rule import TriggerRule
from mock import patch
@ -201,6 +204,96 @@ class DagTest(unittest.TestCase):
self.assertEquals(tuple(), dag.topological_sort())
def test_dag_task_priority_weight_total(self):
    """priority_weight_total honours downstream/upstream/absolute rules."""
    width = 5
    depth = 5
    pattern = re.compile('stage(\\d*).(\\d*)')

    # Each scenario: (priority_weight, extra operator kwargs,
    # expected total weight as a function of the task's stage depth).
    scenarios = [
        # Default (downstream): itself plus every later stage's tasks.
        (5, {},
         lambda d: ((depth - (d + 1)) * width + 1) * 5),
        # Upstream: itself plus every earlier stage's tasks.
        (3, {'weight_rule': WeightRule.UPSTREAM},
         lambda d: (d * width + 1) * 3),
        # Absolute: the exact priority_weight, no aggregation.
        (10, {'weight_rule': WeightRule.ABSOLUTE},
         lambda d: 10),
    ]
    for weight, extra_kwargs, expected in scenarios:
        # Fully connected parallel stages: every task of a stage depends
        # on every task of the previous stage.
        with DAG('dag', start_date=DEFAULT_DATE,
                 default_args={'owner': 'owner1'}) as dag:
            pipeline = [
                [DummyOperator(
                    task_id='stage{}.{}'.format(i, j),
                    priority_weight=weight, **extra_kwargs)
                 for j in range(0, width)] for i in range(0, depth)
            ]
            for d, stage in enumerate(pipeline):
                if d == 0:
                    continue
                for current_task in stage:
                    for prev_task in pipeline[d - 1]:
                        current_task.set_upstream(prev_task)
            for task in six.itervalues(dag.task_dict):
                match = pattern.match(task.task_id)
                task_depth = int(match.group(1))
                self.assertEquals(task.priority_weight_total,
                                  expected(task_depth))

    # An unknown rule must be rejected at operator construction time.
    with DAG('dag', start_date=DEFAULT_DATE,
             default_args={'owner': 'owner1'}):
        with self.assertRaises(AirflowException):
            DummyOperator(task_id='should_fail', weight_rule='no rule')
def test_get_num_task_instances(self):
test_dag_id = 'test_get_num_task_instances_dag'
test_task_id = 'task_1'