[AIRFLOW-71] Add support for private Docker images
Pulling images from private Docker registries requires authentication, so additional parameters are added in order to perform the login step.
This commit is contained in:
Родитель
2fef9152be
Коммит
f101ff0063
|
@ -32,6 +32,7 @@ import sys
|
|||
#
|
||||
# ------------------------------------------------------------------------
|
||||
_hooks = {
|
||||
'docker_hook': ['DockerHook'],
|
||||
'ftp_hook': ['FTPHook'],
|
||||
'ftps_hook': ['FTPSHook'],
|
||||
'vertica_hook': ['VerticaHook'],
|
||||
|
|
|
@ -0,0 +1,79 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from docker import Client
|
||||
from docker.errors import APIError
|
||||
|
||||
from airflow.exceptions import AirflowException
|
||||
from airflow.hooks.base_hook import BaseHook
|
||||
from airflow.utils.log.logging_mixin import LoggingMixin
|
||||
|
||||
class DockerHook(BaseHook, LoggingMixin):
    """
    Interact with a private Docker registry.

    The hook builds a Docker client for the given daemon and logs it into
    the registry described by the Airflow connection before handing the
    client back to the caller.

    :param docker_conn_id: ID of the Airflow connection where
        credentials and extra configuration are stored
    :type docker_conn_id: str
    :param base_url: value forwarded to the Docker client as ``base_url``;
        must be provided
    :type base_url: str
    :param version: value forwarded to the Docker client as ``version``;
        must be provided
    :type version: str
    :param tls: value forwarded unchanged to the Docker client as ``tls``
    :raises AirflowException: if ``base_url`` or ``version`` is missing, or
        if the connection has no host (registry URL) or no login (username)
    """
    def __init__(self,
                 docker_conn_id='docker_default',
                 base_url=None,
                 version=None,
                 tls=None
                 ):
        # Validate the caller-supplied arguments before touching the
        # connection store.
        if not base_url:
            raise AirflowException('No Docker base URL provided')
        if not version:
            raise AirflowException('No Docker API version provided')

        conn = self.get_connection(docker_conn_id)
        if not conn.host:
            raise AirflowException('No Docker registry URL provided')
        if not conn.login:
            raise AirflowException('No username provided')
        extra_options = conn.extra_dejson

        self.__base_url = base_url
        self.__version = version
        self.__tls = tls
        self.__registry = conn.host
        self.__username = conn.login
        self.__password = conn.password
        # Optional settings read from the connection's extra JSON.
        self.__email = extra_options.get('email')
        # Re-authentication stays enabled unless the extras say "no".
        self.__reauth = extra_options.get('reauth') != 'no'

    def get_conn(self):
        """
        Create a Docker client and log it into the configured registry.

        :return: the logged-in Docker client
        :raises AirflowException: if the registry login fails
        """
        client = Client(
            base_url=self.__base_url,
            version=self.__version,
            tls=self.__tls
        )
        self.__login(client)
        return client

    def __login(self, client):
        """
        Log ``client`` into the registry using the connection credentials.

        :param client: Docker client to authenticate
        :raises AirflowException: if the Docker API reports an error
        """
        self.log.debug('Logging into Docker registry')
        try:
            client.login(
                username=self.__username,
                password=self.__password,
                registry=self.__registry,
                email=self.__email,
                reauth=self.__reauth
            )
            self.log.debug('Login successful')
        except APIError as docker_error:
            self.log.error('Docker registry login failed: %s', str(docker_error))
            # Interpolate the message here: exceptions, unlike logger calls,
            # do not apply %-formatting to extra positional arguments.
            raise AirflowException(
                'Docker registry login failed: %s' % str(docker_error)
            )
|
|
@ -530,6 +530,7 @@ class Connection(Base, LoggingMixin):
|
|||
_extra = Column('extra', String(5000))
|
||||
|
||||
_types = [
|
||||
('docker', 'Docker Registry',),
|
||||
('fs', 'File (path)'),
|
||||
('ftp', 'FTP',),
|
||||
('google_cloud_platform', 'Google Cloud Platform'),
|
||||
|
@ -696,6 +697,9 @@ class Connection(Base, LoggingMixin):
|
|||
elif self.conn_type == 'wasb':
|
||||
from airflow.contrib.hooks.wasb_hook import WasbHook
|
||||
return WasbHook(wasb_conn_id=self.conn_id)
|
||||
elif self.conn_type == 'docker':
|
||||
from airflow.hooks.docker_hook import DockerHook
|
||||
return DockerHook(docker_conn_id=self.conn_id)
|
||||
except:
|
||||
pass
|
||||
|
||||
|
|
|
@ -14,11 +14,12 @@
|
|||
|
||||
import json
|
||||
|
||||
from airflow.hooks.docker_hook import DockerHook
|
||||
from airflow.exceptions import AirflowException
|
||||
from airflow.models import BaseOperator
|
||||
from airflow.utils.decorators import apply_defaults
|
||||
from airflow.utils.file import TemporaryDirectory
|
||||
from docker import APIClient as Client, tls
|
||||
from docker import Client, tls
|
||||
import ast
|
||||
|
||||
|
||||
|
@ -30,9 +31,14 @@ class DockerOperator(BaseOperator):
|
|||
that together exceed the default disk size of 10GB in a container. The path to the mounted
|
||||
directory can be accessed via the environment variable ``AIRFLOW_TMP_DIR``.
|
||||
|
||||
If a login to a private registry is required prior to pulling the image, a
|
||||
Docker connection needs to be configured in Airflow and the connection ID
|
||||
be provided with the parameter ``docker_conn_id``.
|
||||
|
||||
:param image: Docker image from which to create the container.
|
||||
:type image: str
|
||||
:param api_version: Remote API version.
|
||||
:param api_version: Remote API version. Set to ``auto`` to automatically
|
||||
detect the server's version.
|
||||
:type api_version: str
|
||||
:param command: Command to be run in the container.
|
||||
:type command: str or list
|
||||
|
@ -41,10 +47,11 @@ class DockerOperator(BaseOperator):
|
|||
https://docs.docker.com/engine/reference/run/#cpu-share-constraint
|
||||
:type cpus: float
|
||||
:param docker_url: URL of the host running the docker daemon.
|
||||
Default is unix://var/run/docker.sock
|
||||
:type docker_url: str
|
||||
:param environment: Environment variables to set in the container.
|
||||
:type environment: dict
|
||||
:param force_pull: Pull the docker image on every run.
|
||||
:param force_pull: Pull the docker image on every run. Default is false.
|
||||
:type force_pull: bool
|
||||
:param mem_limit: Maximum amount of memory the container can use. Either a float value, which
|
||||
represents the limit in bytes, or a string like ``128m`` or ``1g``.
|
||||
|
@ -78,6 +85,8 @@ class DockerOperator(BaseOperator):
|
|||
:type xcom_push: bool
|
||||
:param xcom_all: Push all the stdout or just the last line. The default is False (last line).
|
||||
:type xcom_all: bool
|
||||
:param docker_conn_id: ID of the Airflow connection to use
|
||||
:type docker_conn_id: str
|
||||
"""
|
||||
template_fields = ('command',)
|
||||
template_ext = ('.sh', '.bash',)
|
||||
|
@ -105,6 +114,7 @@ class DockerOperator(BaseOperator):
|
|||
working_dir=None,
|
||||
xcom_push=False,
|
||||
xcom_all=False,
|
||||
docker_conn_id=None,
|
||||
*args,
|
||||
**kwargs):
|
||||
|
||||
|
@ -129,25 +139,32 @@ class DockerOperator(BaseOperator):
|
|||
self.working_dir = working_dir
|
||||
self.xcom_push_flag = xcom_push
|
||||
self.xcom_all = xcom_all
|
||||
self.docker_conn_id = docker_conn_id
|
||||
|
||||
self.cli = None
|
||||
self.container = None
|
||||
|
||||
def get_hook(self):
    """
    Build the DockerHook used when ``docker_conn_id`` is configured.

    :return: a DockerHook set up with this operator's daemon URL,
        API version and TLS configuration
    """
    # Build the TLS config first: when TLS is configured it rewrites
    # ``self.docker_url`` from tcp:// to https://, and the hook must
    # receive the rewritten URL.
    tls_config = self.__get_tls_config()
    return DockerHook(
        docker_conn_id=self.docker_conn_id,
        # The daemon URL is kept in ``docker_url``, the attribute the
        # rest of this operator reads when it builds a client.
        base_url=self.docker_url,
        version=self.api_version,
        tls=tls_config
    )
|
||||
|
||||
def execute(self, context):
|
||||
self.log.info('Starting docker container from image %s', self.image)
|
||||
|
||||
tls_config = None
|
||||
if self.tls_ca_cert and self.tls_client_cert and self.tls_client_key:
|
||||
tls_config = tls.TLSConfig(
|
||||
ca_cert=self.tls_ca_cert,
|
||||
client_cert=(self.tls_client_cert, self.tls_client_key),
|
||||
verify=True,
|
||||
ssl_version=self.tls_ssl_version,
|
||||
assert_hostname=self.tls_hostname
|
||||
)
|
||||
self.docker_url = self.docker_url.replace('tcp://', 'https://')
|
||||
tls_config = self.__get_tls_config()
|
||||
|
||||
self.cli = Client(base_url=self.docker_url, version=self.api_version, tls=tls_config)
|
||||
if self.docker_conn_id:
|
||||
self.cli = self.get_hook().get_conn()
|
||||
else:
|
||||
self.cli = Client(
|
||||
base_url=self.docker_url,
|
||||
version=self.api_version,
|
||||
tls=tls_config
|
||||
)
|
||||
|
||||
if ':' not in self.image:
|
||||
image = self.image + ':latest'
|
||||
|
@ -204,3 +221,16 @@ class DockerOperator(BaseOperator):
|
|||
if self.cli is not None:
|
||||
self.log.info('Stopping docker container')
|
||||
self.cli.stop(self.container['Id'])
|
||||
|
||||
def __get_tls_config(self):
    """
    Return the TLS configuration for the Docker client, if any.

    A configuration is only built when the CA certificate, the client
    certificate and the client key are all set. In that case
    ``self.docker_url`` is also switched from ``tcp://`` to ``https://``.

    :return: a ``tls.TLSConfig`` instance, or ``None`` when TLS is not
        fully configured
    """
    if not (self.tls_ca_cert and self.tls_client_cert and self.tls_client_key):
        return None

    config = tls.TLSConfig(
        ca_cert=self.tls_ca_cert,
        client_cert=(self.tls_client_cert, self.tls_client_key),
        verify=True,
        ssl_version=self.tls_ssl_version,
        assert_hostname=self.tls_hostname
    )
    # Side effect: the daemon URL is rewritten to use https.
    self.docker_url = self.docker_url.replace('tcp://', 'https://')
    return config
|
||||
|
|
|
@ -38,7 +38,14 @@
|
|||
'login': 'Username (or API Key)',
|
||||
'schema': 'Database'
|
||||
}
|
||||
}
|
||||
},
|
||||
docker: {
|
||||
hidden_fields: ['port', 'schema'],
|
||||
relabeling: {
|
||||
'host': 'Registry URL',
|
||||
'login': 'Username',
|
||||
},
|
||||
},
|
||||
}
|
||||
function connTypeChange(connectionType) {
|
||||
$("div.form-group").removeClass("hide");
|
||||
|
|
|
@ -216,6 +216,7 @@ Hooks
|
|||
:show-inheritance:
|
||||
:members:
|
||||
DbApiHook,
|
||||
DockerHook,
|
||||
HiveCliHook,
|
||||
HiveMetastoreHook,
|
||||
HiveServer2Hook,
|
||||
|
|
|
@ -0,0 +1,176 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from airflow import configuration
|
||||
from airflow import models
|
||||
from airflow.exceptions import AirflowException
|
||||
from airflow.utils import db
|
||||
|
||||
try:
|
||||
from airflow.hooks.docker_hook import DockerHook
|
||||
from docker import Client
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
from unittest import mock
|
||||
except ImportError:
|
||||
try:
|
||||
import mock
|
||||
except ImportError:
|
||||
mock = None
|
||||
|
||||
|
||||
@mock.patch('airflow.hooks.docker_hook.Client', autospec=True)
class DockerHookTest(unittest.TestCase):
    """
    Tests for DockerHook.

    The Docker ``Client`` used by the hook module is patched for every test
    method, so each test receives the patched class as an extra argument.
    """

    def setUp(self):
        """Register the two Docker connections shared by the tests."""
        configuration.load_test_config()
        db.merge_conn(
            models.Connection(
                conn_id='docker_default',
                conn_type='docker',
                host='some.docker.registry.com',
                login='some_user',
                password='some_p4$$w0rd'
            )
        )
        db.merge_conn(
            models.Connection(
                conn_id='docker_with_extras',
                conn_type='docker',
                host='some.docker.registry.com',
                login='some_user',
                password='some_p4$$w0rd',
                extra='{"email": "some@example.com", "reauth": "no"}'
            )
        )

    def test_init_fails_when_no_base_url_given(self, _):
        """Constructing the hook without ``base_url`` must raise."""
        with self.assertRaises(AirflowException):
            DockerHook(
                docker_conn_id='docker_default',
                version='auto',
                tls=None
            )

    def test_init_fails_when_no_api_version_given(self, _):
        """Constructing the hook without ``version`` must raise."""
        with self.assertRaises(AirflowException):
            DockerHook(
                docker_conn_id='docker_default',
                base_url='unix://var/run/docker.sock',
                tls=None
            )

    def test_get_conn_override_defaults(self, docker_client_mock):
        """The constructor arguments are forwarded to the Docker client."""
        hook = DockerHook(
            docker_conn_id='docker_default',
            base_url='https://index.docker.io/v1/',
            version='1.23',
            tls='someconfig'
        )
        hook.get_conn()
        docker_client_mock.assert_called_with(
            base_url='https://index.docker.io/v1/',
            version='1.23',
            tls='someconfig'
        )

    def test_get_conn_with_standard_config(self, _):
        """A client is returned for a connection without extras."""
        # No blanket try/except here: an unexpected exception should surface
        # with its own traceback instead of a generic failure message.
        hook = DockerHook(
            docker_conn_id='docker_default',
            base_url='unix://var/run/docker.sock',
            version='auto'
        )
        client = hook.get_conn()
        self.assertIsNotNone(client)

    def test_get_conn_with_extra_config(self, _):
        """A client is returned for a connection with extras."""
        hook = DockerHook(
            docker_conn_id='docker_with_extras',
            base_url='unix://var/run/docker.sock',
            version='auto'
        )
        client = hook.get_conn()
        self.assertIsNotNone(client)

    def test_conn_with_standard_config_passes_parameters(self, _):
        """Login receives the connection credentials and default extras."""
        hook = DockerHook(
            docker_conn_id='docker_default',
            base_url='unix://var/run/docker.sock',
            version='auto'
        )
        client = hook.get_conn()
        client.login.assert_called_with(
            username='some_user',
            password='some_p4$$w0rd',
            registry='some.docker.registry.com',
            reauth=True,
            email=None
        )

    def test_conn_with_extra_config_passes_parameters(self, _):
        """Login receives the email and reauth values from the extras."""
        hook = DockerHook(
            docker_conn_id='docker_with_extras',
            base_url='unix://var/run/docker.sock',
            version='auto'
        )
        client = hook.get_conn()
        client.login.assert_called_with(
            username='some_user',
            password='some_p4$$w0rd',
            registry='some.docker.registry.com',
            reauth=False,
            email='some@example.com'
        )

    def test_conn_with_broken_config_missing_username_fails(self, _):
        """A connection without a login must be rejected."""
        db.merge_conn(
            models.Connection(
                conn_id='docker_without_username',
                conn_type='docker',
                host='some.docker.registry.com',
                password='some_p4$$w0rd',
                extra='{"email": "some@example.com"}'
            )
        )
        with self.assertRaises(AirflowException):
            DockerHook(
                docker_conn_id='docker_without_username',
                base_url='unix://var/run/docker.sock',
                version='auto'
            )

    def test_conn_with_broken_config_missing_host_fails(self, _):
        """A connection without a host must be rejected."""
        db.merge_conn(
            models.Connection(
                conn_id='docker_without_host',
                conn_type='docker',
                login='some_user',
                password='some_p4$$w0rd'
            )
        )
        with self.assertRaises(AirflowException):
            DockerHook(
                docker_conn_id='docker_without_host',
                base_url='unix://var/run/docker.sock',
                version='auto'
            )
|
|
@ -17,7 +17,8 @@ import logging
|
|||
|
||||
try:
|
||||
from airflow.operators.docker_operator import DockerOperator
|
||||
from docker import APIClient as Client
|
||||
from airflow.hooks.docker_hook import DockerHook
|
||||
from docker import Client
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
@ -33,7 +34,6 @@ except ImportError:
|
|||
|
||||
|
||||
class DockerOperatorTestCase(unittest.TestCase):
|
||||
@unittest.skipIf(mock is None, 'mock package not present')
|
||||
@mock.patch('airflow.utils.file.mkdtemp')
|
||||
@mock.patch('airflow.operators.docker_operator.Client')
|
||||
def test_execute(self, client_class_mock, mkdtemp_mock):
|
||||
|
@ -77,7 +77,6 @@ class DockerOperatorTestCase(unittest.TestCase):
|
|||
client_mock.pull.assert_called_with('ubuntu:latest', stream=True)
|
||||
client_mock.wait.assert_called_with('some_id')
|
||||
|
||||
@unittest.skipIf(mock is None, 'mock package not present')
|
||||
@mock.patch('airflow.operators.docker_operator.tls.TLSConfig')
|
||||
@mock.patch('airflow.operators.docker_operator.Client')
|
||||
def test_execute_tls(self, client_class_mock, tls_class_mock):
|
||||
|
@ -105,7 +104,6 @@ class DockerOperatorTestCase(unittest.TestCase):
|
|||
client_class_mock.assert_called_with(base_url='https://127.0.0.1:2376', tls=tls_mock,
|
||||
version=None)
|
||||
|
||||
@unittest.skipIf(mock is None, 'mock package not present')
|
||||
@mock.patch('airflow.operators.docker_operator.Client')
|
||||
def test_execute_unicode_logs(self, client_class_mock):
|
||||
client_mock = mock.Mock(spec=Client)
|
||||
|
@ -128,7 +126,6 @@ class DockerOperatorTestCase(unittest.TestCase):
|
|||
logging.raiseExceptions = originalRaiseExceptions
|
||||
print_exception_mock.assert_not_called()
|
||||
|
||||
@unittest.skipIf(mock is None, 'mock package not present')
|
||||
@mock.patch('airflow.operators.docker_operator.Client')
|
||||
def test_execute_container_fails(self, client_class_mock):
|
||||
client_mock = mock.Mock(spec=Client)
|
||||
|
@ -146,7 +143,6 @@ class DockerOperatorTestCase(unittest.TestCase):
|
|||
with self.assertRaises(AirflowException):
|
||||
operator.execute(None)
|
||||
|
||||
@unittest.skipIf(mock is None, 'mock package not present')
|
||||
def test_on_kill(self):
|
||||
client_mock = mock.Mock(spec=Client)
|
||||
|
||||
|
@ -158,6 +154,80 @@ class DockerOperatorTestCase(unittest.TestCase):
|
|||
|
||||
client_mock.stop.assert_called_with('some_id')
|
||||
|
||||
@mock.patch('airflow.operators.docker_operator.Client')
def test_execute_no_docker_conn_id_no_hook(self, operator_client_mock):
    """Without ``docker_conn_id`` the operator must not ask for a hook."""
    # Stub Docker client whose operations return harmless values.
    docker_client = mock.Mock(name='DockerOperator.Client mock', spec=Client)
    docker_client.configure_mock(**{
        'images.return_value': [],
        'create_container.return_value': {'Id': 'some_id'},
        'logs.return_value': [],
        'pull.return_value': [],
        'wait.return_value': 0,
    })
    operator_client_mock.return_value = docker_client

    # Operator under test, created without a Docker connection ID.
    docker_operator = DockerOperator(
        image='publicregistry/someimage',
        owner='unittest',
        task_id='unittest'
    )

    # Replace get_hook so any call to it can be counted.
    stub_hook = mock.Mock(name='DockerHook mock', spec=DockerHook)
    stub_hook.get_conn.return_value = docker_client
    docker_operator.get_hook = mock.Mock(
        name='DockerOperator.get_hook mock',
        spec=DockerOperator.get_hook,
        return_value=stub_hook
    )

    docker_operator.execute(None)

    hook_calls = docker_operator.get_hook.call_count
    self.assertEqual(
        hook_calls, 0,
        'Hook called though no docker_conn_id configured'
    )
|
||||
|
||||
@mock.patch('airflow.operators.docker_operator.Client')
def test_execute_with_docker_conn_id_use_hook(self, operator_client_mock):
    """With ``docker_conn_id`` set, the client must come from the hook."""
    # Stub Docker client whose operations return harmless values.
    docker_client = mock.Mock(name='DockerOperator.Client mock', spec=Client)
    docker_client.configure_mock(**{
        'images.return_value': [],
        'create_container.return_value': {'Id': 'some_id'},
        'logs.return_value': [],
        'pull.return_value': [],
        'wait.return_value': 0,
    })
    operator_client_mock.return_value = docker_client

    # Operator under test, created with a Docker connection ID.
    docker_operator = DockerOperator(
        image='publicregistry/someimage',
        owner='unittest',
        task_id='unittest',
        docker_conn_id='some_conn_id'
    )

    # Replace get_hook so the operator receives the stub client through it.
    stub_hook = mock.Mock(name='DockerHook mock', spec=DockerHook)
    stub_hook.get_conn.return_value = docker_client
    docker_operator.get_hook = mock.Mock(
        name='DockerOperator.get_hook mock',
        spec=DockerOperator.get_hook,
        return_value=stub_hook
    )

    docker_operator.execute(None)

    # The operator-level Client class must stay untouched ...
    self.assertEqual(
        operator_client_mock.call_count, 0,
        'Client was called on the operator instead of the hook'
    )
    # ... the hook must have been requested exactly once ...
    self.assertEqual(
        docker_operator.get_hook.call_count, 1,
        'Hook was not called although docker_conn_id configured'
    )
    # ... and the image pulled once through the client it returned.
    self.assertEqual(
        docker_client.pull.call_count, 1,
        'Image was not pulled using operator client'
    )
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
Загрузка…
Ссылка в новой задаче