[AIRFLOW-6793] Respect env variable in airflow config command (#7413)

This commit is contained in:
Kamil Breguła 2020-02-18 11:59:29 +01:00 коммит произвёл GitHub
Родитель 3f25ff93cb
Коммит 7c95c8144d
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
3 изменённых файлов: 48 добавлений и 13 удалений

Просмотреть файл

@ -31,9 +31,7 @@ log = logging.getLogger(__name__)
broker_url = conf.get('celery', 'BROKER_URL')
broker_transport_options = conf.getsection(
'celery_broker_transport_options'
)
broker_transport_options = conf.getsection('celery_broker_transport_options') or {}
if 'visibility_timeout' not in broker_transport_options:
if _broker_supports_visibility_timeout(broker_url):
broker_transport_options['visibility_timeout'] = 21600

Просмотреть файл

@ -28,7 +28,7 @@ from base64 import b64encode
from collections import OrderedDict
# Ignored Mypy on configparser because it thinks the configparser module has no _UNSET attribute
from configparser import _UNSET, ConfigParser, NoOptionError, NoSectionError # type: ignore
from typing import Dict, Tuple
from typing import Dict, Optional, Tuple, Union
import yaml
from cryptography.fernet import Fernet
@ -335,7 +335,7 @@ class AirflowConfigParser(ConfigParser):
if self.airflow_defaults.has_option(section, option) and remove_default:
self.airflow_defaults.remove_option(section, option)
def getsection(self, section):
def getsection(self, section: str) -> Optional[Dict[str, Union[str, int, float, bool]]]:
"""
Returns the section as a dict. Values are converted to int, float, bool
as required.
@ -343,22 +343,24 @@ class AirflowConfigParser(ConfigParser):
:param section: section from the config
:rtype: dict
"""
if (section not in self._sections and
section not in self.airflow_defaults._sections):
if (section not in self._sections and section not in self.airflow_defaults._sections): # type: ignore
return None
_section = copy.deepcopy(self.airflow_defaults._sections[section])
_section = copy.deepcopy(self.airflow_defaults._sections[section]) # type: ignore
if section in self._sections:
_section.update(copy.deepcopy(self._sections[section]))
if section in self._sections: # type: ignore
_section.update(copy.deepcopy(self._sections[section])) # type: ignore
section_prefix = 'AIRFLOW__{S}__'.format(S=section.upper())
for env_var in sorted(os.environ.keys()):
if env_var.startswith(section_prefix):
key = env_var.replace(section_prefix, '').lower()
key = env_var.replace(section_prefix, '')
if key.endswith("_CMD"):
key = key[:-4]
key = key.lower()
_section[key] = self._get_env_var_option(section, key)
for key, val in _section.items():
for key, val in _section.items(): # type: ignore
try:
val = int(val)
except ValueError:
@ -372,6 +374,18 @@ class AirflowConfigParser(ConfigParser):
_section[key] = val
return _section
def write(self, fp, space_around_delimiters=True):
# This is based on the configparser.RawConfigParser.write method code to add support for
# reading options from environment variables.
if space_around_delimiters:
d = " {} ".format(self._delimiters[0]) # type: ignore
else:
d = self._delimiters[0] # type: ignore
if self._defaults:
self._write_section(fp, self.default_section, self._defaults.items(), d) # type: ignore
for section in self._sections:
self._write_section(fp, section, self.getsection(section).items(), d) # type: ignore
def as_dict(
self, display_source=False, display_sensitive=False, raw=False,
include_env=True, include_cmds=True) -> Dict[str, Dict[str, str]]:

Просмотреть файл

@ -15,8 +15,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import io
import os
import tempfile
import unittest
import warnings
from collections import OrderedDict
@ -321,6 +322,21 @@ key3 = value3
test_conf.getsection('testsection')
)
def test_get_section_should_respect_cmd_env_variable(self):
with tempfile.NamedTemporaryFile(delete=False) as cmd_file:
cmd_file.write("#!/usr/bin/env bash\n".encode())
cmd_file.write("echo -n difficult_unpredictable_cat_password\n".encode())
cmd_file.flush()
os.chmod(cmd_file.name, 0o0555)
cmd_file.close()
with mock.patch.dict(
"os.environ", {"AIRFLOW__KUBERNETES__GIT_PASSWORD_CMD": cmd_file.name}
):
content = conf.getsection("kubernetes")
os.unlink(cmd_file.name)
self.assertEqual(content["git_password"], "difficult_unpredictable_cat_password")
def test_kubernetes_environment_variables_section(self):
test_config = '''
[kubernetes_environment_variables]
@ -535,6 +551,13 @@ notacommand = OK
self.assertEqual(value, fernet_key)
@mock.patch.dict("os.environ", {"AIRFLOW__CORE__DAGS_FOLDER": "/tmp/test_folder"})
def test_write_should_respect_env_variable(self):
with io.StringIO() as string_file:
conf.write(string_file)
content = string_file.getvalue()
self.assertIn("dags_folder = /tmp/test_folder", content)
def test_run_command(self):
write = r'sys.stdout.buffer.write("\u1000foo".encode("utf8"))'