[AIRFLOW-2203] Cache signature in apply_defaults
Cache inspect.signature for the wrapper closure to avoid calling it at every decorated invocation. This is separate sig_cache created per decoration, i.e. each function decorated using apply_defaults will have a different sig_cache.
This commit is contained in:
Родитель
92357d53e6
Коммит
81ec595b6c
|
@ -39,6 +39,19 @@ def apply_defaults(func):
|
|||
inheritance and argument defaults, this decorator also alerts with
|
||||
specific information about the missing arguments.
|
||||
"""
|
||||
|
||||
import airflow.models
|
||||
# Cache inspect.signature for the wrapper closure to avoid calling it
|
||||
# at every decorated invocation. This is separate sig_cache created
|
||||
# per decoration, i.e. each function decorated using apply_defaults will
|
||||
# have a different sig_cache.
|
||||
sig_cache = signature(func)
|
||||
non_optional_args = {
|
||||
name for (name, param) in sig_cache.parameters.items()
|
||||
if param.default == param.empty and
|
||||
param.name != 'self' and
|
||||
param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)}
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
if len(args) > 1:
|
||||
|
@ -46,9 +59,9 @@ def apply_defaults(func):
|
|||
"Use keyword arguments when initializing operators")
|
||||
dag_args = {}
|
||||
dag_params = {}
|
||||
import airflow.models
|
||||
if kwargs.get('dag', None) or airflow.models._CONTEXT_MANAGER_DAG:
|
||||
dag = kwargs.get('dag', None) or airflow.models._CONTEXT_MANAGER_DAG
|
||||
|
||||
dag = kwargs.get('dag', None) or airflow.models._CONTEXT_MANAGER_DAG
|
||||
if dag:
|
||||
dag_args = copy(dag.default_args) or {}
|
||||
dag_params = copy(dag.params) or {}
|
||||
|
||||
|
@ -67,16 +80,10 @@ def apply_defaults(func):
|
|||
dag_args.update(default_args)
|
||||
default_args = dag_args
|
||||
|
||||
sig = signature(func)
|
||||
non_optional_args = [
|
||||
name for (name, param) in sig.parameters.items()
|
||||
if param.default == param.empty and
|
||||
param.name != 'self' and
|
||||
param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)]
|
||||
for arg in sig.parameters:
|
||||
if arg in default_args and arg not in kwargs:
|
||||
for arg in sig_cache.parameters:
|
||||
if arg not in kwargs and arg in default_args:
|
||||
kwargs[arg] = default_args[arg]
|
||||
missing_args = list(set(non_optional_args) - set(kwargs))
|
||||
missing_args = list(non_optional_args - set(kwargs))
|
||||
if missing_args:
|
||||
msg = "Argument {0} is required".format(missing_args)
|
||||
raise AirflowException(msg)
|
||||
|
|
|
@ -0,0 +1,69 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
from airflow.utils.decorators import apply_defaults
|
||||
from airflow.exceptions import AirflowException
|
||||
|
||||
|
||||
# Essentially similar to airflow.models.BaseOperator
|
||||
class DummyClass(object):
|
||||
@apply_defaults
|
||||
def __init__(self, test_param, params=None, default_args=None):
|
||||
self.test_param = test_param
|
||||
|
||||
|
||||
class DummySubClass(DummyClass):
|
||||
@apply_defaults
|
||||
def __init__(self, test_sub_param, *args, **kwargs):
|
||||
super(DummySubClass, self).__init__(*args, **kwargs)
|
||||
self.test_sub_param = test_sub_param
|
||||
|
||||
|
||||
class ApplyDefaultTest(unittest.TestCase):
|
||||
|
||||
def test_apply(self):
|
||||
dc = DummyClass(test_param=True)
|
||||
self.assertTrue(dc.test_param)
|
||||
|
||||
with self.assertRaisesRegexp(AirflowException, 'Argument.*test_param.*required'):
|
||||
DummySubClass(test_sub_param=True)
|
||||
|
||||
def test_default_args(self):
|
||||
default_args = {'test_param': True}
|
||||
dc = DummyClass(default_args=default_args)
|
||||
self.assertTrue(dc.test_param)
|
||||
|
||||
default_args = {'test_param': True, 'test_sub_param': True}
|
||||
dsc = DummySubClass(default_args=default_args)
|
||||
self.assertTrue(dc.test_param)
|
||||
self.assertTrue(dsc.test_sub_param)
|
||||
|
||||
default_args = {'test_param': True}
|
||||
dsc = DummySubClass(default_args=default_args, test_sub_param=True)
|
||||
self.assertTrue(dc.test_param)
|
||||
self.assertTrue(dsc.test_sub_param)
|
||||
|
||||
with self.assertRaisesRegexp(AirflowException,
|
||||
'Argument.*test_sub_param.*required'):
|
||||
DummySubClass(default_args=default_args)
|
||||
|
||||
def test_incorrect_default_args(self):
|
||||
default_args = {'test_param': True, 'extra_param': True}
|
||||
dc = DummyClass(default_args=default_args)
|
||||
self.assertTrue(dc.test_param)
|
||||
|
||||
default_args = {'random_params': True}
|
||||
with self.assertRaisesRegexp(AirflowException, 'Argument.*test_param.*required'):
|
||||
DummyClass(default_args=default_args)
|
Загрузка…
Ссылка в новой задаче