[AIRFLOW-4057] statsd should handle invalid characters (#4889)

This commit is contained in:
Chao-Han Tsai 2019-03-21 21:13:29 -07:00 коммит произвёл Tao Feng
Родитель dd6e8bc49e
Коммит dce353957b
11 изменённых файлов: 223 добавлений и 39 удалений

Просмотреть файл

@ -58,6 +58,10 @@ class AirflowRescheduleException(AirflowException):
self.reschedule_date = reschedule_date
class InvalidStatsNameException(AirflowException):
    """Raised when a stats name fails validation.

    ``airflow.stats.stat_name_default_handler`` raises it for names that are
    not strings, are too long, or contain characters it does not allow.
    """
class AirflowTaskTimeout(AirflowException):
    """Airflow exception type for task timeouts.

    NOTE(review): the raise sites are outside this view; semantics are taken
    from the class name only.
    """

Просмотреть файл

@ -23,7 +23,7 @@ from collections import OrderedDict
# To avoid circular imports
import airflow.utils.dag_processing
from airflow import configuration
from airflow.settings import Stats
from airflow.stats import Stats
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.state import State

Просмотреть файл

@ -46,7 +46,7 @@ from airflow.exceptions import AirflowException
from airflow.models import DAG, DagRun, errors
from airflow.models.dagpickle import DagPickle
from airflow.models.slamiss import SlaMiss
from airflow.settings import Stats
from airflow.stats import Stats
from airflow.task.task_runner import get_task_runner
from airflow.ti_deps.dep_context import DepContext, QUEUE_DEPS, RUN_DEPS
from airflow.utils import asciiart, helpers, timezone

Просмотреть файл

@ -90,6 +90,7 @@ from airflow.models.log import Log
from airflow.models.taskfail import TaskFail
from airflow.models.taskreschedule import TaskReschedule
from airflow.models.xcom import XCom
from airflow.stats import Stats
from airflow.ti_deps.deps.not_in_retry_period_dep import NotInRetryPeriodDep
from airflow.ti_deps.deps.prev_dagrun_dep import PrevDagrunDep
from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep
@ -115,8 +116,6 @@ install_aliases()
XCOM_RETURN_KEY = 'return_value'
Stats = settings.Stats
class InvalidFernetToken(Exception):
# If Fernet isn't loaded we need a valid exception class to catch. If it is

Просмотреть файл

@ -56,6 +56,13 @@ class AirflowPlugin(object):
appbuilder_views = [] # type: List[Any]
appbuilder_menu_items = [] # type: List[Any]
# A function that validates the statsd stat name, applies changes
# to the stat name if necessary and returns the transformed stat name.
#
# The function should have the following signature:
# def func_name(stat_name: str) -> str:
stat_name_handler = None # type:Any
@classmethod
def validate(cls):
if not cls.name:
@ -182,7 +189,9 @@ flask_blueprints = [] # type: List[Any]
menu_links = [] # type: List[Any]
flask_appbuilder_views = [] # type: List[Any]
flask_appbuilder_menu_links = [] # type: List[Any]
stat_name_handler = None # type: Any
stat_name_handlers = []
for p in plugins:
operators_modules.append(
make_module('airflow.operators.' + p.name, p.operators + p.sensors))
@ -202,3 +211,12 @@ for p in plugins:
'name': p.name,
'blueprint': bp
} for bp in p.flask_blueprints])
if p.stat_name_handler:
stat_name_handlers.append(p.stat_name_handler)
# Only a single plugin may contribute a stat_name_handler; anything beyond
# that is treated as a configuration error.
if len(stat_name_handlers) > 1:
    raise AirflowPluginException(
        ('Specified more than one stat_name_handler ({}) '
         'is not allowed.').format(stat_name_handlers))

# Past the check above the list holds zero or one handler, so this yields
# that handler, or None when no plugin supplied one.
stat_name_handler = next(iter(stat_name_handlers), None)

Просмотреть файл

@ -26,7 +26,7 @@ import atexit
import logging
import os
import pendulum
import socket
from sqlalchemy import create_engine, exc
from sqlalchemy.orm import scoped_session, sessionmaker
@ -51,38 +51,6 @@ except Exception:
log.info("Configured default timezone %s" % TIMEZONE)
class DummyStatsLogger(object):
    """No-op stats logger: every method accepts its arguments and does nothing."""

    @classmethod
    def incr(cls, stat, count=1, rate=1):
        """Accept an increment request and discard it."""
        return None

    @classmethod
    def decr(cls, stat, count=1, rate=1):
        """Accept a decrement request and discard it."""
        return None

    @classmethod
    def gauge(cls, stat, value, rate=1, delta=False):
        """Accept a gauge update and discard it."""
        return None

    @classmethod
    def timing(cls, stat, dt):
        """Accept a timing sample and discard it."""
        return None
# Start with the no-op logger; it is replaced below only when statsd is
# switched on in the [scheduler] configuration and a client can be built.
Stats = DummyStatsLogger
try:
    if conf.getboolean('scheduler', 'statsd_on'):
        # Imported only when statsd is enabled, so a missing ``statsd``
        # package surfaces as the ImportError handled below.
        from statsd import StatsClient
        statsd = StatsClient(
            host=conf.get('scheduler', 'statsd_host'),
            port=conf.getint('scheduler', 'statsd_port'),
            prefix=conf.get('scheduler', 'statsd_prefix'))
        Stats = statsd
except (socket.gaierror, ImportError) as e:
    # Best effort: keep the DummyStatsLogger assigned above.
    # NOTE(review): socket.gaierror is presumably raised when the configured
    # host cannot be resolved - confirm against the statsd client.
    log.warning("Could not configure StatsClient: %s, using DummyStatsLogger instead.", e)
HEADER = '\n'.join([
r' ____________ _____________',
r' ____ |__( )_________ __/__ /________ __',

126
airflow/stats.py Normal file
Просмотреть файл

@ -0,0 +1,126 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from functools import wraps
import logging
from six import string_types
import socket
import string
import textwrap
from typing import Any
from airflow import configuration as conf
from airflow.exceptions import InvalidStatsNameException
log = logging.getLogger(__name__)
class DummyStatsLogger(object):
    """Stats logger that silently ignores every call.

    It exposes ``incr``, ``decr``, ``gauge`` and ``timing`` as class methods,
    each of which returns ``None`` without doing any work.
    """

    @classmethod
    def incr(cls, stat, count=1, rate=1):
        """Ignore an increment of ``stat``."""
        return None

    @classmethod
    def decr(cls, stat, count=1, rate=1):
        """Ignore a decrement of ``stat``."""
        return None

    @classmethod
    def gauge(cls, stat, value, rate=1, delta=False):
        """Ignore a gauge update for ``stat``."""
        return None

    @classmethod
    def timing(cls, stat, dt):
        """Ignore a timing sample for ``stat``."""
        return None
# Only characters in the character set are considered valid
# for the stat_name if stat_name_default_handler is used.
ALLOWED_CHARACTERS = set(string.ascii_letters + string.digits + '_.-')
def stat_name_default_handler(stat_name, max_length=250):
    """Validate a statsd stat name and return it unchanged.

    :param stat_name: candidate stat name
    :param max_length: maximum accepted length; a name of exactly
        ``max_length`` characters is still accepted
    :return: ``stat_name`` itself when it passes every check
    :raises InvalidStatsNameException: if ``stat_name`` is not a string, is
        longer than ``max_length``, or contains a character that is not in
        ``ALLOWED_CHARACTERS``
    """
    if not isinstance(stat_name, string_types):
        raise InvalidStatsNameException('The stat_name has to be a string')
    if len(stat_name) > max_length:
        # Dedent the template *before* substituting values: a stat_name that
        # contains a newline would otherwise change the common indentation
        # and leave the message with stray leading whitespace.
        raise InvalidStatsNameException(textwrap.dedent("""\
            The stat_name ({stat_name}) has to be less than {max_length} characters.
        """).format(stat_name=stat_name, max_length=max_length))
    if not all((c in ALLOWED_CHARACTERS) for c in stat_name):
        # Render the allowed set as a sorted string so the message is stable
        # across runs (a set's repr order is not).
        raise InvalidStatsNameException(textwrap.dedent("""\
            The stat name ({stat_name}) has to be composed with characters in
            {allowed_characters}.
        """).format(stat_name=stat_name,
                    allowed_characters=''.join(sorted(ALLOWED_CHARACTERS))))
    return stat_name
def validate_stat(f):
    """Decorate a stats-logger method so its stat name is validated first.

    The wrapped callable is expected to be an instance method whose first
    argument after ``self`` is the stat name. The name is passed through the
    plugin-provided ``stat_name_handler`` when one is set, otherwise through
    ``stat_name_default_handler``; the handler's return value is what the
    wrapped method receives.

    If the handler raises, a warning is logged and the wrapper returns
    ``None`` without calling the wrapped method, so a bad stat name never
    propagates an exception to the caller.

    :param f: method with signature ``f(self, stat, *args, **kwargs)``
    :return: the wrapping function
    """
    @wraps(f)
    def wrapper(_self, stat, *args, **kwargs):
        # ``_self`` must be captured explicitly: the decorator is applied to
        # methods, so the first positional argument is the instance, not the
        # stat name.
        try:
            # Imported at call time rather than module import time
            # (presumably to avoid a circular import - confirm).
            from airflow.plugins_manager import stat_name_handler
            handle_stat_name_func = stat_name_handler or stat_name_default_handler
            stat_name = handle_stat_name_func(stat)
        except Exception as err:
            # Deliberately broad: plugin handlers may raise anything, and
            # emitting metrics must never break the caller.
            log.warning('Invalid stat name: %s. %s', stat, err)
            return None
        return f(_self, stat_name, *args, **kwargs)
    return wrapper
class SafeStatsdLogger(object):
    """Statsd client wrapper whose methods are guarded by ``validate_stat``.

    Each public method forwards its arguments positionally to the method of
    the same name on the wrapped client and returns that call's result.
    """

    def __init__(self, statsd_client):
        """Keep a reference to the client that performs the real work."""
        self.statsd = statsd_client

    @validate_stat
    def incr(self, stat, count=1, rate=1):
        """Forward an increment to the wrapped client."""
        client = self.statsd
        return client.incr(stat, count, rate)

    @validate_stat
    def decr(self, stat, count=1, rate=1):
        """Forward a decrement to the wrapped client."""
        client = self.statsd
        return client.decr(stat, count, rate)

    @validate_stat
    def gauge(self, stat, value, rate=1, delta=False):
        """Forward a gauge update to the wrapped client."""
        client = self.statsd
        return client.gauge(stat, value, rate, delta)

    @validate_stat
    def timing(self, stat, dt):
        """Forward a timing sample to the wrapped client."""
        client = self.statsd
        return client.timing(stat, dt)
# Start with the no-op logger; it is replaced below only when statsd is
# switched on in the [scheduler] configuration and a client can be built.
Stats = DummyStatsLogger  # type: Any
try:
    if conf.getboolean('scheduler', 'statsd_on'):
        # Imported only when statsd is enabled, so a missing ``statsd``
        # package surfaces as the ImportError handled below.
        from statsd import StatsClient
        statsd = StatsClient(
            host=conf.get('scheduler', 'statsd_host'),
            port=conf.getint('scheduler', 'statsd_port'),
            prefix=conf.get('scheduler', 'statsd_prefix'))
        # Wrap the raw client so stat names are validated before being sent.
        Stats = SafeStatsdLogger(statsd)
except (socket.gaierror, ImportError) as e:
    # Best effort: keep the DummyStatsLogger assigned above.
    # NOTE(review): socket.gaierror is presumably raised when the configured
    # host cannot be resolved - confirm against the statsd client.
    log.warning("Could not configure StatsClient: %s, using DummyStatsLogger instead.", e)

Просмотреть файл

@ -48,7 +48,8 @@ from airflow import configuration as conf
from airflow.dag.base_dag import BaseDag, BaseDagBag
from airflow.exceptions import AirflowException
from airflow.models import errors
from airflow.settings import logging_class_path, Stats
from airflow.settings import logging_class_path
from airflow.stats import Stats
from airflow.utils import timezone
from airflow.utils.db import provide_session
from airflow.utils.log.logging_mixin import LoggingMixin

Просмотреть файл

@ -92,6 +92,12 @@ looks like:
appbuilder_views = []
# A list of dictionaries containing FlaskAppBuilder BaseView object and some metadata. See example below
appbuilder_menu_items = []
# A function that validates the statsd stat name, applies changes to the stat name if necessary and
# returns the transformed stat name.
#
# The function should have the following signature:
# def func_name(stat_name: str) -> str:
stat_name_handler = None
# A callback to perform actions when airflow starts and the plugin is loaded.
# NOTE: Ensure your plugin has *args, and **kwargs in the method definition
# to protect against extra parameters injected into the on_load(...)
@ -191,6 +197,10 @@ definitions in Airflow.
"category_icon": "fa-th",
"href": "https://www.google.com"}
# Validate the statsd stat name
def stat_name_dummy_handler(stat_name):
    """Accept any stat name and hand it back unchanged."""
    return stat_name
# Defining the plugin class
class AirflowTestPlugin(AirflowPlugin):
name = "test_plugin"
@ -202,6 +212,7 @@ definitions in Airflow.
flask_blueprints = [bp]
appbuilder_views = [v_appbuilder_package]
appbuilder_menu_items = [appbuilder_mitem]
stat_name_handler = stat_name_dummy_handler
Note on role based views

Просмотреть файл

@ -75,7 +75,6 @@ appbuilder_mitem = {"name": "Google",
"category_icon": "fa-th",
"href": "https://www.google.com"}
# Creating a flask blueprint to integrate the templates and static folder
bp = Blueprint(
"test_plugin", __name__,
@ -84,6 +83,11 @@ bp = Blueprint(
static_url_path='/static/test_plugin')
# Create a handler to validate statsd stat name
def stat_name_dummy_handler(stat_name):
    """Identity handler: return the given stat name without any validation."""
    return stat_name
# Defining the plugin class
class AirflowTestPlugin(AirflowPlugin):
name = "test_plugin"
@ -95,6 +99,7 @@ class AirflowTestPlugin(AirflowPlugin):
flask_blueprints = [bp]
appbuilder_views = [v_appbuilder_package]
appbuilder_menu_items = [appbuilder_mitem]
stat_name_handler = stat_name_dummy_handler
class MockPluginA(AirflowPlugin):

52
tests/test_stats.py Normal file
Просмотреть файл

@ -0,0 +1,52 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import unittest
from airflow.exceptions import InvalidStatsNameException
from airflow.stats import stat_name_default_handler
class TestStats(unittest.TestCase):
    """Unit tests for ``stat_name_default_handler``."""

    def test_stat_name_default_handler_success(self):
        """A valid name is returned unchanged."""
        stat_name = 'task_run'
        stat_name_ = stat_name_default_handler(stat_name)
        self.assertEqual(stat_name, stat_name_)

    def test_stat_name_default_handler_not_string(self):
        """A non-string stat name is rejected."""
        with self.assertRaises(InvalidStatsNameException):
            stat_name_default_handler(list())

    def test_stat_name_default_handler_exceed_max_length(self):
        """A name longer than ``max_length`` is rejected."""
        with self.assertRaises(InvalidStatsNameException):
            stat_name_default_handler('123456', 3)

    def test_stat_name_default_handler_invalid_character(self):
        """A name containing ``:`` is rejected."""
        with self.assertRaises(InvalidStatsNameException):
            stat_name_default_handler(':123456')