[AIRFLOW-192] Add weight_rule param to BaseOperator
Significantly improved task generation performance by using sets of task_ids and dag_ids instead of lists when calculating total priority weight. Closes #2941 from wongwill86/performance-latest
This commit is contained in:
Parent: fbba5ef7c3
Commit: dd2bc8cb97
airflow/models.py

@@ -19,7 +19,6 @@ from __future__ import unicode_literals
 from future.standard_library import install_aliases
-install_aliases()
 from builtins import str
 from builtins import object, bytes
 import copy
@@ -84,8 +83,11 @@ from airflow.utils.operator_resources import Resources
 from airflow.utils.state import State
 from airflow.utils.timeout import timeout
 from airflow.utils.trigger_rule import TriggerRule
+from airflow.utils.weight_rule import WeightRule
 from airflow.utils.log.logging_mixin import LoggingMixin
 
+install_aliases()
+
 Base = declarative_base()
 ID_LEN = 250
 XCOM_RETURN_KEY = 'return_value'
@@ -2073,6 +2075,29 @@ class BaseOperator(LoggingMixin):
         This allows the executor to trigger higher priority tasks before
         others when things get backed up.
     :type priority_weight: int
+    :param weight_rule: weighting method used for the effective total
+        priority weight of the task. Options are:
+        ``{ downstream | upstream | absolute }``; default is ``downstream``.
+        When set to ``downstream``, the effective weight of the task is the
+        aggregate sum of all downstream descendants. As a result, upstream
+        tasks will have higher weight and will be scheduled more aggressively
+        when using positive weight values. This is useful when you have
+        multiple dag run instances and want all upstream tasks to complete
+        for all runs before each dag can continue processing downstream
+        tasks. When set to ``upstream``, the effective weight is the
+        aggregate sum of all upstream ancestors. This is the opposite:
+        downstream tasks have higher weight and will be scheduled more
+        aggressively when using positive weight values. This is useful when
+        you have multiple dag run instances and prefer to have each dag
+        complete before starting upstream tasks of other dags. When set to
+        ``absolute``, the effective weight is the exact ``priority_weight``
+        specified without additional weighting. You may want to do this when
+        you know exactly what priority weight each task should have.
+        Additionally, when set to ``absolute``, there is a bonus effect of
+        significantly speeding up the task creation process for very large
+        DAGs. Options can be set as string or using the constants defined in
+        the static class ``airflow.utils.WeightRule``.
+    :type weight_rule: str
     :param pool: the slot pool this task should run in, slot pools are a
         way to limit concurrency for certain tasks
     :type pool: str
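As a quick illustration of the three rules described in the docstring above, here is a minimal sketch (the DAG id, dates, and task names are invented for this example and are not part of the commit):

```python
from datetime import datetime

from airflow import DAG
from airflow.operators.dummy_operator import DummyOperator
from airflow.utils.weight_rule import WeightRule

with DAG('weight_rule_demo', start_date=datetime(2018, 1, 1)) as dag:
    # Default 'downstream' rule: effective weight = own weight plus all
    # downstream descendants, so earlier tasks are prioritized across runs.
    extract = DummyOperator(task_id='extract', priority_weight=2)
    # 'upstream' rule: effective weight = own weight plus all upstream
    # ancestors, so later tasks win and each run tends to finish first.
    load = DummyOperator(task_id='load', priority_weight=2,
                         weight_rule=WeightRule.UPSTREAM)
    # 'absolute' rule: effective weight is exactly priority_weight, and the
    # relative traversal is skipped entirely (faster on very large DAGs).
    report = DummyOperator(task_id='report', priority_weight=10,
                           weight_rule=WeightRule.ABSOLUTE)
    extract >> load >> report
```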
@@ -2150,6 +2175,7 @@
             default_args=None,
             adhoc=False,
             priority_weight=1,
+            weight_rule=WeightRule.DOWNSTREAM,
             queue=configuration.get('celery', 'default_queue'),
             pool=None,
             sla=None,
@@ -2190,7 +2216,7 @@
                 "The trigger_rule must be one of {all_triggers},"
                 "'{d}.{t}'; received '{tr}'."
                 .format(all_triggers=TriggerRule.all_triggers,
-                        d=dag.dag_id, t=task_id, tr=trigger_rule))
+                        d=dag.dag_id if dag else "", t=task_id, tr=trigger_rule))

         self.trigger_rule = trigger_rule
         self.depends_on_past = depends_on_past
@@ -2224,6 +2250,14 @@
         self.params = params or {}  # Available in templates!
         self.adhoc = adhoc
         self.priority_weight = priority_weight
+        if not WeightRule.is_valid(weight_rule):
+            raise AirflowException(
+                "The weight_rule must be one of {all_weight_rules},"
+                "'{d}.{t}'; received '{tr}'."
+                .format(all_weight_rules=WeightRule.all_weight_rules(),
+                        d=dag.dag_id if dag else "", t=task_id, tr=weight_rule))
+        self.weight_rule = weight_rule
+
         self.resources = Resources(**(resources or {}))
         self.run_as_user = run_as_user
         self.task_concurrency = task_concurrency
@@ -2402,10 +2436,19 @@

     @property
     def priority_weight_total(self):
-        return sum([
-            t.priority_weight
-            for t in self.get_flat_relatives(upstream=False)
-        ]) + self.priority_weight
+        if self.weight_rule == WeightRule.ABSOLUTE:
+            return self.priority_weight
+        elif self.weight_rule == WeightRule.DOWNSTREAM:
+            upstream = False
+        elif self.weight_rule == WeightRule.UPSTREAM:
+            upstream = True
+        else:
+            upstream = False
+
+        return self.priority_weight + sum(
+            map(lambda task_id: self._dag.task_dict[task_id].priority_weight,
+                self.get_flat_relative_ids(upstream=upstream))
+        )

     def pre_execute(self, context):
         """
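A worked example of what the rewritten property computes under the default ``downstream`` rule (a hypothetical three-task chain, not part of the commit):

```python
from datetime import datetime

from airflow import DAG
from airflow.operators.dummy_operator import DummyOperator

with DAG('chain_demo', start_date=datetime(2018, 1, 1)) as dag:
    a = DummyOperator(task_id='a', priority_weight=1)
    b = DummyOperator(task_id='b', priority_weight=1)
    c = DummyOperator(task_id='c', priority_weight=1)
    a >> b >> c

# Each total is the task's own weight plus all downstream descendants:
assert a.priority_weight_total == 3  # a + b + c
assert b.priority_weight_total == 2  # b + c
assert c.priority_weight_total == 1  # c only
```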
@@ -2608,17 +2651,30 @@
             TI.execution_date <= end_date,
         ).order_by(TI.execution_date).all()

-    def get_flat_relatives(self, upstream=False, l=None):
+    def get_flat_relative_ids(self, upstream=False, found_descendants=None):
         """
-        Get a flat list of relatives, either upstream or downstream.
+        Get a flat list of relatives' ids, either upstream or downstream.
         """
-        if not l:
-            l = []
-        for t in self.get_direct_relatives(upstream):
-            if not is_in(t, l):
-                l.append(t)
-                t.get_flat_relatives(upstream, l)
-        return l
+
+        if not found_descendants:
+            found_descendants = set()
+        relative_ids = self.get_direct_relative_ids(upstream)
+
+        for relative_id in relative_ids:
+            if relative_id not in found_descendants:
+                found_descendants.add(relative_id)
+                relative_task = self._dag.task_dict[relative_id]
+                relative_task.get_flat_relative_ids(upstream,
+                                                    found_descendants)
+
+        return found_descendants
+
+    def get_flat_relatives(self, upstream=False):
+        """
+        Get a flat list of relatives, either upstream or downstream.
+        """
+        return list(map(lambda task_id: self._dag.task_dict[task_id],
+                        self.get_flat_relative_ids(upstream)))

     def detect_downstream_cycle(self, task=None):
         """
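The performance gain claimed in the commit message comes largely from this rewrite: the old code checked membership with `is_in(t, l)` against a growing list, an O(n) scan per visit, while the new code tests `relative_id not in found_descendants` against a set, an O(1) hash lookup. A rough, illustrative micro-benchmark (not part of the commit):

```python
import timeit

ids = ['task_{}'.format(i) for i in range(5000)]
as_list = list(ids)
as_set = set(ids)

# Worst case for the list: the sought element is last, so every probe
# walks all 5000 entries; the set hashes once per probe.
print(timeit.timeit(lambda: 'task_4999' in as_list, number=1000))
print(timeit.timeit(lambda: 'task_4999' in as_set, number=1000))
```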
@@ -2664,6 +2720,16 @@
             self.log.info('Rendering template for %s', attr)
             self.log.info(content)

+    def get_direct_relative_ids(self, upstream=False):
+        """
+        Get the direct relative ids to the current task, upstream or
+        downstream.
+        """
+        if upstream:
+            return self._upstream_task_ids
+        else:
+            return self._downstream_task_ids
+
     def get_direct_relatives(self, upstream=False):
         """
         Get the direct relatives to the current task, upstream or
@@ -2704,14 +2770,14 @@

         # relationships can only be set if the tasks share a single DAG. Tasks
         # without a DAG are assigned to that DAG.
-        dags = set(t.dag for t in [self] + task_list if t.has_dag())
+        dags = {t._dag.dag_id: t.dag for t in [self] + task_list if t.has_dag()}

         if len(dags) > 1:
             raise AirflowException(
                 'Tried to set relationships between tasks in '
-                'more than one DAG: {}'.format(dags))
+                'more than one DAG: {}'.format(dags.values()))
         elif len(dags) == 1:
-            dag = list(dags)[0]
+            dag = dags.popitem()[1]
         else:
             raise AirflowException(
                 "Tried to create relationships between tasks that don't have "
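An illustrative sketch of why keying the dict by ``dag_id`` can beat a ``set`` of DAG objects: dict membership compares short strings, whereas a set must hash and compare the objects themselves, which for rich objects can touch many attributes. The stub class below is invented for this sketch; how expensive DAG's own equality is depends on the Airflow version.

```python
# Stub standing in for an object whose __eq__/__hash__ walk a large payload.
class SlowDag(object):
    def __init__(self, dag_id, payload):
        self.dag_id = dag_id
        self.payload = payload  # imagine many compared fields

    def __eq__(self, other):
        return (self.dag_id, self.payload) == (other.dag_id, other.payload)

    def __hash__(self):
        return hash((self.dag_id, tuple(self.payload)))


dags = [SlowDag('etl', list(range(10000))) for _ in range(50)]

unique_by_object = set(dags)                # hashes the whole payload, 50 times
unique_by_id = {d.dag_id: d for d in dags}  # hashes one short string per object

assert len(unique_by_object) == len(unique_by_id) == 1
```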
@@ -4739,7 +4805,7 @@ class DagRun(Base, LoggingMixin):
             ti.state = State.REMOVED

         # check for missing tasks
-        for task in dag.tasks:
+        for task in six.itervalues(dag.task_dict):
             if task.adhoc:
                 continue
airflow/utils/weight_rule.py (new file)

@@ -0,0 +1,33 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from __future__ import unicode_literals
+
+from builtins import object
+
+
+class WeightRule(object):
+    DOWNSTREAM = 'downstream'
+    UPSTREAM = 'upstream'
+    ABSOLUTE = 'absolute'
+
+    @classmethod
+    def is_valid(cls, weight_rule):
+        return weight_rule in cls.all_weight_rules()
+
+    @classmethod
+    def all_weight_rules(cls):
+        return [getattr(cls, attr)
+                for attr in dir(cls)
+                if not attr.startswith("__") and not callable(getattr(cls, attr))]
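A short sketch of the helper API the new module exposes (`dir()` returns attributes alphabetically, hence the ordering):

```python
from airflow.utils.weight_rule import WeightRule

print(WeightRule.all_weight_rules())    # ['absolute', 'downstream', 'upstream']
print(WeightRule.is_valid('upstream'))  # True
print(WeightRule.is_valid('no rule'))   # False
```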
tests/models.py

@@ -23,6 +23,8 @@ import os
 import pendulum
 import unittest
 import time
+import six
+import re

 from airflow import configuration, models, settings, AirflowException
 from airflow.exceptions import AirflowSkipException
@@ -39,6 +41,7 @@ from airflow.operators.python_operator import PythonOperator
 from airflow.operators.python_operator import ShortCircuitOperator
 from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep
 from airflow.utils import timezone
+from airflow.utils.weight_rule import WeightRule
 from airflow.utils.state import State
 from airflow.utils.trigger_rule import TriggerRule
 from mock import patch
@@ -201,6 +204,96 @@ class DagTest(unittest.TestCase):

         self.assertEquals(tuple(), dag.topological_sort())

+    def test_dag_task_priority_weight_total(self):
+        width = 5
+        depth = 5
+        weight = 5
+        pattern = re.compile('stage(\\d*).(\\d*)')
+        # Fully connected parallel tasks, i.e. every task at each parallel
+        # stage is dependent on every task in the previous stage.
+        # Default weight should be calculated using downstream descendants
+        with DAG('dag', start_date=DEFAULT_DATE,
+                 default_args={'owner': 'owner1'}) as dag:
+            pipeline = [
+                [DummyOperator(
+                    task_id='stage{}.{}'.format(i, j), priority_weight=weight)
+                    for j in range(0, width)] for i in range(0, depth)
+            ]
+            for d, stage in enumerate(pipeline):
+                if d == 0:
+                    continue
+                for current_task in stage:
+                    for prev_task in pipeline[d - 1]:
+                        current_task.set_upstream(prev_task)
+
+            for task in six.itervalues(dag.task_dict):
+                match = pattern.match(task.task_id)
+                task_depth = int(match.group(1))
+                # the sum of each stage after this task + itself
+                correct_weight = ((depth - (task_depth + 1)) * width + 1) * weight
+
+                calculated_weight = task.priority_weight_total
+                self.assertEquals(calculated_weight, correct_weight)
+
+        # Same test as above except use 'upstream' for weight calculation
+        weight = 3
+        with DAG('dag', start_date=DEFAULT_DATE,
+                 default_args={'owner': 'owner1'}) as dag:
+            pipeline = [
+                [DummyOperator(
+                    task_id='stage{}.{}'.format(i, j), priority_weight=weight,
+                    weight_rule=WeightRule.UPSTREAM)
+                    for j in range(0, width)] for i in range(0, depth)
+            ]
+            for d, stage in enumerate(pipeline):
+                if d == 0:
+                    continue
+                for current_task in stage:
+                    for prev_task in pipeline[d - 1]:
+                        current_task.set_upstream(prev_task)
+
+            for task in six.itervalues(dag.task_dict):
+                match = pattern.match(task.task_id)
+                task_depth = int(match.group(1))
+                # the sum of each stage before this task + itself
+                correct_weight = ((task_depth) * width + 1) * weight
+
+                calculated_weight = task.priority_weight_total
+                self.assertEquals(calculated_weight, correct_weight)
+
+        # Same test as above except use 'absolute' for weight calculation
+        weight = 10
+        with DAG('dag', start_date=DEFAULT_DATE,
+                 default_args={'owner': 'owner1'}) as dag:
+            pipeline = [
+                [DummyOperator(
+                    task_id='stage{}.{}'.format(i, j), priority_weight=weight,
+                    weight_rule=WeightRule.ABSOLUTE)
+                    for j in range(0, width)] for i in range(0, depth)
+            ]
+            for d, stage in enumerate(pipeline):
+                if d == 0:
+                    continue
+                for current_task in stage:
+                    for prev_task in pipeline[d - 1]:
+                        current_task.set_upstream(prev_task)
+
+            for task in six.itervalues(dag.task_dict):
+                match = pattern.match(task.task_id)
+                task_depth = int(match.group(1))
+                # the absolute weight is the task's own priority_weight
+                correct_weight = weight
+
+                calculated_weight = task.priority_weight_total
+                self.assertEquals(calculated_weight, correct_weight)
+
+        # Test if we enter an invalid weight rule
+        with DAG('dag', start_date=DEFAULT_DATE,
+                 default_args={'owner': 'owner1'}) as dag:
+            with self.assertRaises(AirflowException):
+                DummyOperator(task_id='should_fail', weight_rule='no rule')
+
+
     def test_get_num_task_instances(self):
         test_dag_id = 'test_get_num_task_instances_dag'
         test_task_id = 'task_1'
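A sanity check of the downstream formula used in the test above: with width = depth = weight = 5, a task at stage ``d`` has ``(depth - d - 1) * width`` downstream descendants, each contributing ``weight``, plus its own weight:

```python
width = depth = weight = 5
for d in range(depth):
    total = ((depth - (d + 1)) * width + 1) * weight
    print('stage{}: priority_weight_total = {}'.format(d, total))
# stage0: 105, stage1: 80, stage2: 55, stage3: 30, stage4: 5
```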