[AIRFLOW-2203] Cache signature in apply_defaults

Cache inspect.signature for the wrapper closure to avoid calling it at
every decorated invocation. This is separate sig_cache created per
decoration, i.e. each function decorated using apply_defaults will have
a different sig_cache.
This commit is contained in:
wongwill86 2018-03-12 17:08:44 -04:00 коммит произвёл Fokko Driesprong
Родитель 92357d53e6
Коммит 81ec595b6c
2 изменённых файлов: 88 добавлений и 12 удалений

Просмотреть файл

@ -39,6 +39,19 @@ def apply_defaults(func):
inheritance and argument defaults, this decorator also alerts with
specific information about the missing arguments.
"""
import airflow.models
# Cache inspect.signature for the wrapper closure to avoid calling it
# at every decorated invocation. This is separate sig_cache created
# per decoration, i.e. each function decorated using apply_defaults will
# have a different sig_cache.
sig_cache = signature(func)
non_optional_args = {
name for (name, param) in sig_cache.parameters.items()
if param.default == param.empty and
param.name != 'self' and
param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)}
@wraps(func)
def wrapper(*args, **kwargs):
if len(args) > 1:
@ -46,9 +59,9 @@ def apply_defaults(func):
"Use keyword arguments when initializing operators")
dag_args = {}
dag_params = {}
import airflow.models
if kwargs.get('dag', None) or airflow.models._CONTEXT_MANAGER_DAG:
dag = kwargs.get('dag', None) or airflow.models._CONTEXT_MANAGER_DAG
dag = kwargs.get('dag', None) or airflow.models._CONTEXT_MANAGER_DAG
if dag:
dag_args = copy(dag.default_args) or {}
dag_params = copy(dag.params) or {}
@ -67,16 +80,10 @@ def apply_defaults(func):
dag_args.update(default_args)
default_args = dag_args
sig = signature(func)
non_optional_args = [
name for (name, param) in sig.parameters.items()
if param.default == param.empty and
param.name != 'self' and
param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)]
for arg in sig.parameters:
if arg in default_args and arg not in kwargs:
for arg in sig_cache.parameters:
if arg not in kwargs and arg in default_args:
kwargs[arg] = default_args[arg]
missing_args = list(set(non_optional_args) - set(kwargs))
missing_args = list(non_optional_args - set(kwargs))
if missing_args:
msg = "Argument {0} is required".format(missing_args)
raise AirflowException(msg)

Просмотреть файл

@ -0,0 +1,69 @@
# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from airflow.utils.decorators import apply_defaults
from airflow.exceptions import AirflowException
# Essentially similar to airflow.models.BaseOperator
class DummyClass(object):
@apply_defaults
def __init__(self, test_param, params=None, default_args=None):
self.test_param = test_param
class DummySubClass(DummyClass):
@apply_defaults
def __init__(self, test_sub_param, *args, **kwargs):
super(DummySubClass, self).__init__(*args, **kwargs)
self.test_sub_param = test_sub_param
class ApplyDefaultTest(unittest.TestCase):
def test_apply(self):
dc = DummyClass(test_param=True)
self.assertTrue(dc.test_param)
with self.assertRaisesRegexp(AirflowException, 'Argument.*test_param.*required'):
DummySubClass(test_sub_param=True)
def test_default_args(self):
default_args = {'test_param': True}
dc = DummyClass(default_args=default_args)
self.assertTrue(dc.test_param)
default_args = {'test_param': True, 'test_sub_param': True}
dsc = DummySubClass(default_args=default_args)
self.assertTrue(dc.test_param)
self.assertTrue(dsc.test_sub_param)
default_args = {'test_param': True}
dsc = DummySubClass(default_args=default_args, test_sub_param=True)
self.assertTrue(dc.test_param)
self.assertTrue(dsc.test_sub_param)
with self.assertRaisesRegexp(AirflowException,
'Argument.*test_sub_param.*required'):
DummySubClass(default_args=default_args)
def test_incorrect_default_args(self):
default_args = {'test_param': True, 'extra_param': True}
dc = DummyClass(default_args=default_args)
self.assertTrue(dc.test_param)
default_args = {'random_params': True}
with self.assertRaisesRegexp(AirflowException, 'Argument.*test_param.*required'):
DummyClass(default_args=default_args)