incubator-airflow/airflow/providers_manager.py

409 строки
18 KiB
Python

#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Manages all providers."""
import fnmatch
import importlib
import json
import logging
import os
from collections import OrderedDict
from typing import Any, Dict, NamedTuple, Set
import jsonschema
import yaml
from wtforms import Field
from airflow.utils.entry_points import entry_points_with_dist
try:
import importlib.resources as importlib_resources
except ImportError:
# Fall back to the `importlib_resources` backport on Python < 3.7.
import importlib_resources
log = logging.getLogger(__name__)
def _create_provider_schema_validator():
    """
    Builds a JSON schema validator for provider metadata.

    The schema is read from the ``provider.yaml.schema.json`` resource bundled in the ``airflow``
    package; the validator class is the one jsonschema selects for that schema.

    :return: validator instance bound to the provider schema
    """
    raw_schema = importlib_resources.read_text('airflow', 'provider.yaml.schema.json')
    provider_schema = json.loads(raw_schema)
    validator_class = jsonschema.validators.validator_for(provider_schema)
    return validator_class(provider_schema)
def _create_customized_form_field_behaviours_schema_validator():
    """
    Builds a JSON schema validator for customized connection form field behaviours.

    The schema is read from the ``customized_form_field_behaviours.schema.json`` resource bundled
    in the ``airflow`` package; the validator class is the one jsonschema selects for that schema.

    :return: validator instance bound to the field behaviours schema
    """
    raw_schema = importlib_resources.read_text('airflow', 'customized_form_field_behaviours.schema.json')
    behaviours_schema = json.loads(raw_schema)
    validator_class = jsonschema.validators.validator_for(behaviours_schema)
    return validator_class(behaviours_schema)
class ProviderInfo(NamedTuple):
    """Version and metadata of a single discovered provider package."""

    # Version of the provider
    version: str
    # Dictionary with the provider metadata
    provider_info: Dict
class HookInfo(NamedTuple):
    """Description of a hook registered for a connection type."""

    # Name of the hook class
    connection_class: str
    # Name of the hook attribute that keeps the connection id
    connection_id_attribute_name: str
    # Name of the provider package the hook comes from
    package_name: str
    # Name of the hook
    hook_name: str
class ConnectionFormWidgetInfo(NamedTuple):
    """Description of a custom widget added to the connection form."""

    # Name of the hook class that contributed the widget
    connection_class: str
    # Name of the provider package the widget comes from
    package_name: str
    # The form field itself
    field: Field
class ProvidersManager:
    """
    Manages all provider packages. This is a Singleton class: every ``ProvidersManager()`` call
    returns the same instance. Discovery of the providers available in installed packages and in
    local source folders (if airflow is run from sources) is lazy - it is performed once, the first
    time any of the public properties is accessed.
    """

    _instance = None
    resource_version = "0"

    def __new__(cls):
        """Returns the single shared instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # Python runs __init__ on every ``ProvidersManager()`` call, also when __new__ handed back
        # the already existing singleton. Without this guard each such call would throw away the
        # discovered data and force a complete re-discovery.
        if hasattr(self, "_initialized"):
            return
        # Keeps dict of providers keyed by module name
        self._provider_dict: Dict[str, ProviderInfo] = {}
        # Keeps dict of hooks keyed by connection type
        self._hooks_dict: Dict[str, HookInfo] = {}
        # Keeps custom widgets for the connection form, keyed by name of the extra field
        self._connection_form_widgets: Dict[str, ConnectionFormWidgetInfo] = {}
        # Customizations for javascript fields are kept here
        self._field_behaviours: Dict[str, Dict] = {}
        self._extra_link_class_name_set: Set[str] = set()
        self._provider_schema_validator = _create_provider_schema_validator()
        self._customized_form_fields_schema_validator = (
            _create_customized_form_field_behaviours_schema_validator()
        )
        # Assigned last on purpose: the presence of this attribute marks a fully constructed
        # instance for the guard above.
        self._initialized = False

    def initialize_providers_manager(self):
        """Lazy initialization of provider data. Does nothing when discovery has already been done."""
        # We cannot use @cache here because it does not work during pytests, apparently each test
        # runs in its own namespace and ProvidersManager is a different object in each namespace
        # even if it is singleton but @cache on the initialize_providers_manager method still works in
        # the way that it is called only once for one of the objects (at least this is how it looks
        # like from running tests)
        if self._initialized:
            return
        # Local source folders are loaded first. They should take precedence over the package ones for
        # Development purpose. In production provider.yaml files are not present in the 'airflow" directory
        # So there is no risk we are going to override package provider accidentally. This can only happen
        # in case of local development
        self._discover_all_airflow_builtin_providers_from_local_sources()
        self._discover_all_providers_from_packages()
        self._discover_hooks()
        self._provider_dict = OrderedDict(sorted(self._provider_dict.items()))  # noqa
        self._hooks_dict = OrderedDict(sorted(self._hooks_dict.items()))  # noqa
        self._connection_form_widgets = OrderedDict(sorted(self._connection_form_widgets.items()))  # noqa
        self._field_behaviours = OrderedDict(sorted(self._field_behaviours.items()))  # noqa
        self._discover_extra_links()
        self._initialized = True

    def _discover_all_providers_from_packages(self) -> None:
        """
        Discovers all providers by scanning packages installed. The list of providers should be returned
        via the 'apache_airflow_provider' entrypoint as a dictionary conforming to the
        'airflow/provider.yaml.schema.json' schema.

        Packages that are already registered are skipped.

        :raises Exception: when the package name of the distribution differs from the 'package-name'
            declared in the provider info
        """
        for entry_point, dist in entry_points_with_dist('apache_airflow_provider'):
            package_name = dist.metadata['name']
            if package_name in self._provider_dict:
                # Already registered - the first registration wins
                continue
            log.debug("Loading %s from package %s", entry_point, package_name)
            version = dist.version
            # The entrypoint points at a callable returning the provider info dictionary
            provider_info = entry_point.load()()
            self._provider_schema_validator.validate(provider_info)
            provider_info_package_name = provider_info['package-name']
            if package_name != provider_info_package_name:
                raise Exception(
                    f"The package '{package_name}' from setuptools and "
                    f"{provider_info_package_name} do not match. Please make sure they are aligned"
                )
            self._provider_dict[package_name] = ProviderInfo(version, provider_info)

    def _discover_all_airflow_builtin_providers_from_local_sources(self) -> None:
        """
        Finds all built-in airflow providers if airflow is run from the local sources.
        It finds `provider.yaml` files for all such providers and registers the providers using those.

        This 'provider.yaml' scanning takes precedence over scanning packages installed
        in case you have both sources and packages installed, the providers will be loaded from
        the "airflow" sources rather than from the packages.
        """
        try:
            import airflow.providers
        except ImportError:
            log.info("You have no providers installed.")
            return
        try:
            for path in airflow.providers.__path__:
                self._add_provider_info_from_local_source_files_on_path(path)
        except Exception as e:  # noqa pylint: disable=broad-except
            # Best effort: problems with local sources must not prevent discovery from packages
            log.warning("Error when loading 'provider.yaml' files from airflow sources: %s", e)

    def _add_provider_info_from_local_source_files_on_path(self, path) -> None:
        """
        Finds all the provider.yaml files in the directory specified.

        :param path: path where to look for provider.yaml files
        """
        root_path = path
        for folder, subdirs, files in os.walk(path, topdown=True):
            for filename in fnmatch.filter(files, "provider.yaml"):
                # The package name is derived from the folder path relative to the root path
                package_name = "apache-airflow-providers" + folder[len(root_path) :].replace(os.sep, "-")
                # We are skipping discovering snowflake because of snowflake monkeypatching problem
                # This is only for local development - it has no impact for the packaged snowflake provider
                # That should work on its own
                # https://github.com/apache/airflow/issues/12881
                # Once this is back, we can remove this limitation.
                if package_name != "apache-airflow-providers-snowflake":
                    self._add_provider_info_from_local_source_file(
                        os.path.join(folder, filename), package_name
                    )
                # Prune the walk (allowed with topdown=True): folders below one that contains
                # provider.yaml are not visited
                subdirs[:] = []

    def _add_provider_info_from_local_source_file(self, path, package_name) -> None:
        """
        Parses found provider.yaml file and adds found provider to the dictionary.

        Any error is logged as a warning and the file is skipped.

        :param path: full file path of the provider.yaml file
        :param package_name: name of the package
        """
        try:
            log.debug("Loading %s from %s", package_name, path)
            with open(path) as provider_yaml_file:
                provider_info = yaml.safe_load(provider_yaml_file)
            self._provider_schema_validator.validate(provider_info)
            # The first entry of 'versions' is used as the provider version
            version = provider_info['versions'][0]
            if package_name not in self._provider_dict:
                self._provider_dict[package_name] = ProviderInfo(version, provider_info)
            else:
                log.warning(
                    "The providers for package '%s' could not be registered because providers for that "
                    "package name have already been registered",
                    package_name,
                )
        except Exception as e:  # noqa pylint: disable=broad-except
            log.warning("Error when loading '%s': %s", path, e)

    def _discover_hooks(self) -> None:
        """Retrieves all connections defined in the providers"""
        for provider_package, (_, provider_info) in self._provider_dict.items():
            hook_class_names = provider_info.get("hook-class-names")
            if hook_class_names:
                for hook_class_name in hook_class_names:
                    self._add_hook(hook_class_name, provider_package)

    @staticmethod
    def _get_attr(obj: Any, attr_name: str):
        """
        Retrieves attributes of an object, or warns if not found

        :param obj: object to read the attribute from
        :param attr_name: name of the attribute
        :return: value of the attribute, or None when the object does not have it
        """
        if not hasattr(obj, attr_name):
            log.warning("The '%s' is missing %s attribute and cannot be registered", obj, attr_name)
            return None
        return getattr(obj, attr_name)

    def _add_hook(self, hook_class_name: str, provider_package: str) -> None:
        """
        Adds hook class name to list of hooks

        The hook class is imported, its connection form widgets and field behaviours (if it defines
        them itself) are registered and finally the hook is stored under its connection type.
        Problems are logged as warnings and the hook is skipped.

        :param hook_class_name: name of the Hook class
        :param provider_package: provider package adding the hook
        """
        if provider_package.startswith("apache-airflow"):
            # e.g. 'apache-airflow-providers-x' -> 'airflow.providers.x'
            provider_path = provider_package[len("apache-") :].replace("-", ".")
            if not hook_class_name.startswith(provider_path):
                log.warning(
                    "Sanity check failed when importing '%s' from '%s' package. It should start with '%s'",
                    hook_class_name,
                    provider_package,
                    provider_path,
                )
                return
        # The hooks dictionary is keyed by connection type, so the class name has to be looked up
        # among the registered values rather than among the keys.
        if any(hook.connection_class == hook_class_name for hook in self._hooks_dict.values()):
            log.warning(
                "The hook_class '%s' has been already registered.",
                hook_class_name,
            )
            return
        try:
            module, class_name = hook_class_name.rsplit('.', maxsplit=1)
            hook_class = getattr(importlib.import_module(module), class_name)
            # Do not use attr here. We want to check only direct class fields not those
            # inherited from parent hook. This way we add form fields only once for the whole
            # hierarchy and we add it only from the parent hook that provides those!
            if 'get_connection_form_widgets' in hook_class.__dict__:
                widgets = hook_class.get_connection_form_widgets()
                if widgets:
                    self._add_widgets(provider_package, hook_class, widgets)
            if 'get_ui_field_behaviour' in hook_class.__dict__:
                field_behaviours = hook_class.get_ui_field_behaviour()
                if field_behaviours:
                    self._add_customized_fields(provider_package, hook_class, field_behaviours)
        except Exception as e:  # noqa pylint: disable=broad-except
            log.warning(
                "Exception when importing '%s' from '%s' package: %s",
                hook_class_name,
                provider_package,
                e,
            )
            return
        conn_type: str = self._get_attr(hook_class, 'conn_type')
        connection_id_attribute_name: str = self._get_attr(hook_class, 'conn_name_attr')
        hook_name: str = self._get_attr(hook_class, 'hook_name')
        # All three attributes are required (a warning was already logged for each missing one)
        if not conn_type or not connection_id_attribute_name or not hook_name:
            return
        self._hooks_dict[conn_type] = HookInfo(
            hook_class_name,
            connection_id_attribute_name,
            provider_package,
            hook_name,
        )

    def _add_widgets(self, package_name: str, hook_class: type, widgets: Dict[str, Field]):
        """
        Registers connection form widgets provided by the hook class.

        Fields not prefixed with 'extra__' and fields already registered are skipped with a warning.

        :param package_name: provider package the hook comes from
        :param hook_class: hook class providing the widgets
        :param widgets: form fields keyed by field name
        """
        for field_name, field in widgets.items():
            if not field_name.startswith("extra__"):
                log.warning(
                    "The field %s from class %s does not start with 'extra__'. Ignoring it.",
                    field_name,
                    hook_class.__name__,
                )
                continue
            if field_name in self._connection_form_widgets:
                log.warning(
                    "The field %s from class %s has already been added by another provider. Ignoring it.",
                    field_name,
                    hook_class.__name__,
                )
                # In case of inherited hooks this might be happening several times
                continue
            self._connection_form_widgets[field_name] = ConnectionFormWidgetInfo(
                hook_class.__name__, package_name, field
            )

    def _add_customized_fields(self, package_name: str, hook_class: type, customized_fields: Dict):
        """
        Registers customized field behaviours for the connection type of the hook class.

        The behaviours are validated against the customized form fields schema. Any error is logged
        as a warning and the customization is skipped; the first registration for a connection
        type wins.

        :param package_name: provider package the hook comes from
        :param hook_class: hook class providing the customization
        :param customized_fields: dictionary describing the field behaviours
        """
        try:
            connection_type = getattr(hook_class, "conn_type")
            self._customized_form_fields_schema_validator.validate(customized_fields)
            if connection_type in self._field_behaviours:
                log.warning(
                    "The connection_type %s from package %s and class %s has already been added "
                    "by another provider. Ignoring it.",
                    connection_type,
                    package_name,
                    hook_class.__name__,
                )
                return
            self._field_behaviours[connection_type] = customized_fields
        except Exception as e:  # noqa pylint: disable=broad-except
            log.warning(
                "Error when loading customized fields from package '%s' hook class '%s': %s",
                package_name,
                hook_class.__name__,
                e,
            )

    def _discover_extra_links(self) -> None:
        """Retrieves all extra links defined in the providers"""
        for provider_package, (_, provider) in self._provider_dict.items():
            if provider.get("extra-links"):
                for extra_link in provider["extra-links"]:
                    self._add_extra_link(extra_link, provider_package)

    def _add_extra_link(self, extra_link_class_name, provider_package) -> None:
        """
        Adds extra link class name to the list of classes

        :param extra_link_class_name: name of the class to add
        :param provider_package: provider package adding the link
        """
        if provider_package.startswith("apache-airflow"):
            # e.g. 'apache-airflow-providers-x' -> 'airflow.providers.x'
            provider_path = provider_package[len("apache-") :].replace("-", ".")
            if not extra_link_class_name.startswith(provider_path):
                log.warning(
                    "Sanity check failed when importing '%s' from '%s' package. It should start with '%s'",
                    extra_link_class_name,
                    provider_package,
                    provider_path,
                )
                return
        self._extra_link_class_name_set.add(extra_link_class_name)

    @property
    def providers(self) -> Dict[str, ProviderInfo]:
        """Returns information about available providers."""
        self.initialize_providers_manager()
        return self._provider_dict

    @property
    def hooks(self) -> Dict[str, HookInfo]:
        """Returns dictionary of connection_type-to-hook mapping"""
        self.initialize_providers_manager()
        return self._hooks_dict

    @property
    def extra_links_class_names(self):
        """Returns sorted list of extra link class names."""
        self.initialize_providers_manager()
        return sorted(self._extra_link_class_name_set)

    @property
    def connection_form_widgets(self) -> Dict[str, ConnectionFormWidgetInfo]:
        """Returns widgets for connection forms."""
        self.initialize_providers_manager()
        return self._connection_form_widgets

    @property
    def field_behaviours(self) -> Dict[str, Dict]:
        """Returns dictionary with field behaviours for connection types."""
        self.initialize_providers_manager()
        return self._field_behaviours