[AIRFLOW-4057] statsd should handle invalid characters (#4889)
This commit is contained in:
Родитель
dd6e8bc49e
Коммит
dce353957b
|
@ -58,6 +58,10 @@ class AirflowRescheduleException(AirflowException):
|
|||
self.reschedule_date = reschedule_date
|
||||
|
||||
|
||||
class InvalidStatsNameException(AirflowException):
|
||||
pass
|
||||
|
||||
|
||||
class AirflowTaskTimeout(AirflowException):
|
||||
pass
|
||||
|
||||
|
|
|
@ -23,7 +23,7 @@ from collections import OrderedDict
|
|||
# To avoid circular imports
|
||||
import airflow.utils.dag_processing
|
||||
from airflow import configuration
|
||||
from airflow.settings import Stats
|
||||
from airflow.stats import Stats
|
||||
from airflow.utils.log.logging_mixin import LoggingMixin
|
||||
from airflow.utils.state import State
|
||||
|
||||
|
|
|
@ -46,7 +46,7 @@ from airflow.exceptions import AirflowException
|
|||
from airflow.models import DAG, DagRun, errors
|
||||
from airflow.models.dagpickle import DagPickle
|
||||
from airflow.models.slamiss import SlaMiss
|
||||
from airflow.settings import Stats
|
||||
from airflow.stats import Stats
|
||||
from airflow.task.task_runner import get_task_runner
|
||||
from airflow.ti_deps.dep_context import DepContext, QUEUE_DEPS, RUN_DEPS
|
||||
from airflow.utils import asciiart, helpers, timezone
|
||||
|
|
|
@ -90,6 +90,7 @@ from airflow.models.log import Log
|
|||
from airflow.models.taskfail import TaskFail
|
||||
from airflow.models.taskreschedule import TaskReschedule
|
||||
from airflow.models.xcom import XCom
|
||||
from airflow.stats import Stats
|
||||
from airflow.ti_deps.deps.not_in_retry_period_dep import NotInRetryPeriodDep
|
||||
from airflow.ti_deps.deps.prev_dagrun_dep import PrevDagrunDep
|
||||
from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep
|
||||
|
@ -115,8 +116,6 @@ install_aliases()
|
|||
|
||||
XCOM_RETURN_KEY = 'return_value'
|
||||
|
||||
Stats = settings.Stats
|
||||
|
||||
|
||||
class InvalidFernetToken(Exception):
|
||||
# If Fernet isn't loaded we need a valid exception class to catch. If it is
|
||||
|
|
|
@ -56,6 +56,13 @@ class AirflowPlugin(object):
|
|||
appbuilder_views = [] # type: List[Any]
|
||||
appbuilder_menu_items = [] # type: List[Any]
|
||||
|
||||
# A function that validate the statsd stat name, apply changes
|
||||
# to the stat name if necessary and return the transformed stat name.
|
||||
#
|
||||
# The function should have the following signature:
|
||||
# def func_name(stat_name: str) -> str:
|
||||
stat_name_handler = None # type:Any
|
||||
|
||||
@classmethod
|
||||
def validate(cls):
|
||||
if not cls.name:
|
||||
|
@ -182,7 +189,9 @@ flask_blueprints = [] # type: List[Any]
|
|||
menu_links = [] # type: List[Any]
|
||||
flask_appbuilder_views = [] # type: List[Any]
|
||||
flask_appbuilder_menu_links = [] # type: List[Any]
|
||||
stat_name_handler = None # type: Any
|
||||
|
||||
stat_name_handlers = []
|
||||
for p in plugins:
|
||||
operators_modules.append(
|
||||
make_module('airflow.operators.' + p.name, p.operators + p.sensors))
|
||||
|
@ -202,3 +211,12 @@ for p in plugins:
|
|||
'name': p.name,
|
||||
'blueprint': bp
|
||||
} for bp in p.flask_blueprints])
|
||||
if p.stat_name_handler:
|
||||
stat_name_handlers.append(p.stat_name_handler)
|
||||
|
||||
if len(stat_name_handlers) > 1:
|
||||
raise AirflowPluginException(
|
||||
'Specified more than one stat_name_handler ({}) '
|
||||
'is not allowed.'.format(stat_name_handlers))
|
||||
|
||||
stat_name_handler = stat_name_handlers[0] if len(stat_name_handlers) == 1 else None
|
||||
|
|
|
@ -26,7 +26,7 @@ import atexit
|
|||
import logging
|
||||
import os
|
||||
import pendulum
|
||||
import socket
|
||||
|
||||
|
||||
from sqlalchemy import create_engine, exc
|
||||
from sqlalchemy.orm import scoped_session, sessionmaker
|
||||
|
@ -51,38 +51,6 @@ except Exception:
|
|||
log.info("Configured default timezone %s" % TIMEZONE)
|
||||
|
||||
|
||||
class DummyStatsLogger(object):
|
||||
@classmethod
|
||||
def incr(cls, stat, count=1, rate=1):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def decr(cls, stat, count=1, rate=1):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def gauge(cls, stat, value, rate=1, delta=False):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def timing(cls, stat, dt):
|
||||
pass
|
||||
|
||||
|
||||
Stats = DummyStatsLogger
|
||||
|
||||
try:
|
||||
if conf.getboolean('scheduler', 'statsd_on'):
|
||||
from statsd import StatsClient
|
||||
|
||||
statsd = StatsClient(
|
||||
host=conf.get('scheduler', 'statsd_host'),
|
||||
port=conf.getint('scheduler', 'statsd_port'),
|
||||
prefix=conf.get('scheduler', 'statsd_prefix'))
|
||||
Stats = statsd
|
||||
except (socket.gaierror, ImportError) as e:
|
||||
log.warning("Could not configure StatsClient: %s, using DummyStatsLogger instead.", e)
|
||||
|
||||
HEADER = '\n'.join([
|
||||
r' ____________ _____________',
|
||||
r' ____ |__( )_________ __/__ /________ __',
|
||||
|
|
|
@ -0,0 +1,126 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
|
||||
from functools import wraps
|
||||
import logging
|
||||
from six import string_types
|
||||
import socket
|
||||
import string
|
||||
import textwrap
|
||||
from typing import Any
|
||||
|
||||
from airflow import configuration as conf
|
||||
from airflow.exceptions import InvalidStatsNameException
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DummyStatsLogger(object):
|
||||
@classmethod
|
||||
def incr(cls, stat, count=1, rate=1):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def decr(cls, stat, count=1, rate=1):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def gauge(cls, stat, value, rate=1, delta=False):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def timing(cls, stat, dt):
|
||||
pass
|
||||
|
||||
|
||||
# Only characters in the character set are considered valid
|
||||
# for the stat_name if stat_name_default_handler is used.
|
||||
ALLOWED_CHARACTERS = set(string.ascii_letters + string.digits + '_.-')
|
||||
|
||||
|
||||
def stat_name_default_handler(stat_name, max_length=250):
|
||||
if not isinstance(stat_name, string_types):
|
||||
raise InvalidStatsNameException('The stat_name has to be a string')
|
||||
if len(stat_name) > max_length:
|
||||
raise InvalidStatsNameException(textwrap.dedent("""\
|
||||
The stat_name ({stat_name}) has to be less than {max_length} characters.
|
||||
""".format(stat_name=stat_name, max_length=max_length)))
|
||||
if not all((c in ALLOWED_CHARACTERS) for c in stat_name):
|
||||
raise InvalidStatsNameException(textwrap.dedent("""\
|
||||
The stat name ({stat_name}) has to be composed with characters in
|
||||
{allowed_characters}.
|
||||
""".format(stat_name=stat_name,
|
||||
allowed_characters=ALLOWED_CHARACTERS)))
|
||||
return stat_name
|
||||
|
||||
|
||||
def validate_stat(f):
|
||||
@wraps(f)
|
||||
def wrapper(stat, *args, **kwargs):
|
||||
try:
|
||||
from airflow.plugins_manager import stat_name_handler
|
||||
if stat_name_handler:
|
||||
handle_stat_name_func = stat_name_handler
|
||||
else:
|
||||
handle_stat_name_func = stat_name_default_handler
|
||||
stat_name = handle_stat_name_func(stat)
|
||||
except Exception as err:
|
||||
log.warning('Invalid stat name: {stat}.'.format(stat=stat), err)
|
||||
return
|
||||
return f(stat_name, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class SafeStatsdLogger(object):
|
||||
|
||||
def __init__(self, statsd_client):
|
||||
self.statsd = statsd_client
|
||||
|
||||
@validate_stat
|
||||
def incr(self, stat, count=1, rate=1):
|
||||
return self.statsd.incr(stat, count, rate)
|
||||
|
||||
@validate_stat
|
||||
def decr(self, stat, count=1, rate=1):
|
||||
return self.statsd.decr(stat, count, rate)
|
||||
|
||||
@validate_stat
|
||||
def gauge(self, stat, value, rate=1, delta=False):
|
||||
return self.statsd.gauge(stat, value, rate, delta)
|
||||
|
||||
@validate_stat
|
||||
def timing(self, stat, dt):
|
||||
return self.statsd.timing(stat, dt)
|
||||
|
||||
|
||||
Stats = DummyStatsLogger # type: Any
|
||||
|
||||
try:
|
||||
if conf.getboolean('scheduler', 'statsd_on'):
|
||||
from statsd import StatsClient
|
||||
|
||||
statsd = StatsClient(
|
||||
host=conf.get('scheduler', 'statsd_host'),
|
||||
port=conf.getint('scheduler', 'statsd_port'),
|
||||
prefix=conf.get('scheduler', 'statsd_prefix'))
|
||||
Stats = SafeStatsdLogger(statsd)
|
||||
except (socket.gaierror, ImportError) as e:
|
||||
log.warning("Could not configure StatsClient: %s, using DummyStatsLogger instead.", e)
|
|
@ -48,7 +48,8 @@ from airflow import configuration as conf
|
|||
from airflow.dag.base_dag import BaseDag, BaseDagBag
|
||||
from airflow.exceptions import AirflowException
|
||||
from airflow.models import errors
|
||||
from airflow.settings import logging_class_path, Stats
|
||||
from airflow.settings import logging_class_path
|
||||
from airflow.stats import Stats
|
||||
from airflow.utils import timezone
|
||||
from airflow.utils.db import provide_session
|
||||
from airflow.utils.log.logging_mixin import LoggingMixin
|
||||
|
|
|
@ -92,6 +92,12 @@ looks like:
|
|||
appbuilder_views = []
|
||||
# A list of dictionaries containing FlaskAppBuilder BaseView object and some metadata. See example below
|
||||
appbuilder_menu_items = []
|
||||
# A function that validate the statsd stat name, apply changes to the stat name if necessary and
|
||||
# return the transformed stat name.
|
||||
#
|
||||
# The function should have the following signature:
|
||||
# def func_name(stat_name: str) -> str:
|
||||
stat_name_handler = None
|
||||
# A callback to perform actions when airflow starts and the plugin is loaded.
|
||||
# NOTE: Ensure your plugin has *args, and **kwargs in the method definition
|
||||
# to protect against extra parameters injected into the on_load(...)
|
||||
|
@ -191,6 +197,10 @@ definitions in Airflow.
|
|||
"category_icon": "fa-th",
|
||||
"href": "https://www.google.com"}
|
||||
|
||||
# Validate the statsd stat name
|
||||
def stat_name_dummy_handler(stat_name):
|
||||
return stat_name
|
||||
|
||||
# Defining the plugin class
|
||||
class AirflowTestPlugin(AirflowPlugin):
|
||||
name = "test_plugin"
|
||||
|
@ -202,6 +212,7 @@ definitions in Airflow.
|
|||
flask_blueprints = [bp]
|
||||
appbuilder_views = [v_appbuilder_package]
|
||||
appbuilder_menu_items = [appbuilder_mitem]
|
||||
stat_name_handler = stat_name_dummy_handler
|
||||
|
||||
|
||||
Note on role based views
|
||||
|
|
|
@ -75,7 +75,6 @@ appbuilder_mitem = {"name": "Google",
|
|||
"category_icon": "fa-th",
|
||||
"href": "https://www.google.com"}
|
||||
|
||||
|
||||
# Creating a flask blueprint to intergrate the templates and static folder
|
||||
bp = Blueprint(
|
||||
"test_plugin", __name__,
|
||||
|
@ -84,6 +83,11 @@ bp = Blueprint(
|
|||
static_url_path='/static/test_plugin')
|
||||
|
||||
|
||||
# Create a handler to validate statsd stat name
|
||||
def stat_name_dummy_handler(stat_name):
|
||||
return stat_name
|
||||
|
||||
|
||||
# Defining the plugin class
|
||||
class AirflowTestPlugin(AirflowPlugin):
|
||||
name = "test_plugin"
|
||||
|
@ -95,6 +99,7 @@ class AirflowTestPlugin(AirflowPlugin):
|
|||
flask_blueprints = [bp]
|
||||
appbuilder_views = [v_appbuilder_package]
|
||||
appbuilder_menu_items = [appbuilder_mitem]
|
||||
stat_name_handler = stat_name_dummy_handler
|
||||
|
||||
|
||||
class MockPluginA(AirflowPlugin):
|
||||
|
|
|
@ -0,0 +1,52 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from airflow.exceptions import InvalidStatsNameException
|
||||
from airflow.stats import stat_name_default_handler
|
||||
|
||||
|
||||
class TestStats(unittest.TestCase):
|
||||
|
||||
def test_stat_name_default_handler_success(self):
|
||||
stat_name = 'task_run'
|
||||
stat_name_ = stat_name_default_handler(stat_name)
|
||||
self.assertEqual(stat_name, stat_name_)
|
||||
|
||||
def test_stat_name_default_handler_not_string(self):
|
||||
try:
|
||||
stat_name_default_handler(list())
|
||||
except InvalidStatsNameException:
|
||||
return
|
||||
self.fail()
|
||||
|
||||
def test_stat_name_default_handler_exceed_max_length(self):
|
||||
try:
|
||||
stat_name_default_handler('123456', 3)
|
||||
except InvalidStatsNameException:
|
||||
return
|
||||
self.fail()
|
||||
|
||||
def test_stat_name_default_handler_invalid_character(self):
|
||||
try:
|
||||
stat_name_default_handler(':123456')
|
||||
except InvalidStatsNameException:
|
||||
return
|
||||
self.fail()
|
Загрузка…
Ссылка в новой задаче