[AIRFLOW-1314] Basic Kubernetes Mode

This commit is contained in:
dimberman 2017-06-27 09:55:19 -07:00 committed by Fokko Driesprong
Parent f520990fe0
Commit 5821320880
28 changed files with 1618 additions and 416 deletions

View file

@@ -85,6 +85,7 @@ from airflow import sensors # noqa: E402
from airflow import hooks
from airflow import executors
from airflow import macros
from airflow import contrib
operators._integrate_plugins()
sensors._integrate_plugins() # noqa: E402

View file

@@ -0,0 +1,252 @@
# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import calendar
import logging
import time
import os
import multiprocessing
from airflow.contrib.kubernetes.kubernetes_pod_builder import KubernetesPodBuilder
from airflow.contrib.kubernetes.kubernetes_helper import KubernetesHelper
from queue import Queue
from kubernetes import watch
from airflow import settings
from airflow.contrib.kubernetes.kubernetes_request_factory import SimplePodRequestFactory
from airflow.executors.base_executor import BaseExecutor
from airflow.models import TaskInstance
from airflow.utils.state import State
from airflow import configuration
import json
# TODO this is just for proof of concept. remove before merging.
def _prep_command_for_container(command):
"""
When creating a kubernetes pod, the yaml expects the command
in the form of ["cmd","arg","arg","arg"...]
This function splits the command string into tokens
and then matches it to the convention.
:param command:
:return:
"""
return '"' + '","'.join(command.split(' ')[1:]) + '"'
class KubernetesJobWatcher(multiprocessing.Process, object):
def __init__(self, watch_function, namespace, result_queue, watcher_queue):
self.logger = logging.getLogger(__name__)
multiprocessing.Process.__init__(self)
self.result_queue = result_queue
self._watch_function = watch_function
self._watch = watch.Watch()
self.namespace = namespace
self.watcher_queue = watcher_queue
def run(self):
self.logger.info("Event: and now my watch begins")
self.logger.info("Event: proof of image change")
self.logger.info("Event: running {} with {}".format(str(self._watch_function),
self.namespace))
for event in self._watch.stream(self._watch_function, self.namespace):
task = event['object']
self.logger.info("Event: {} had an event of type {}".format(task.metadata.name,
event['type']))
self.process_status(task.metadata.name, task.status.phase)
def process_status(self, job_id, status):
if status == 'Pending':
self.logger.info("Event: {} Pending".format(job_id))
elif status == 'Failed':
self.logger.info("Event: {} Failed".format(job_id))
self.watcher_queue.put((job_id, State.FAILED))
elif status == 'Succeeded':
self.logger.info("Event: {} Succeeded".format(job_id))
self.watcher_queue.put((job_id, None))
elif status == 'Running':
self.logger.info("Event: {} is Running".format(job_id))
else:
self.logger.info("Event: Invalid state {} on job {}".format(status, job_id))
class AirflowKubernetesScheduler(object):
def __init__(self,
task_queue,
result_queue,
running):
self.logger = logging.getLogger(__name__)
self.logger.info("creating kubernetes executor")
self.task_queue = task_queue
self.namespace = os.environ['k8s_POD_NAMESPACE']
self.logger.info("k8s: using namespace {}".format(self.namespace))
self.result_queue = result_queue
self.current_jobs = {}
self.running = running
self._task_counter = 0
self.watcher_queue = multiprocessing.Queue()
self.helper = KubernetesHelper()
w = KubernetesJobWatcher(self.helper.pod_api.list_namespaced_pod, self.namespace,
self.result_queue, self.watcher_queue)
w.start()
def run_next(self, next_job):
"""
The run_next command will check the task_queue for any un-run jobs.
It will then create a unique job-id, launch that job in the cluster,
and store relevant info in the current_jobs map so we can track the job's
status
:return:
"""
self.logger.info('k8s: job is {}'.format(str(next_job)))
(key, command) = next_job
self.logger.info("running for command {}".format(command))
epoch_time = calendar.timegm(time.gmtime())
command_list = ["/usr/local/airflow/entrypoint.sh"] + command.split()[1:] + \
['-km']
self._set_host_id(key)
pod_id = self._create_job_id_from_key(key=key, epoch_time=epoch_time)
self.current_jobs[pod_id] = key
image = configuration.get('core','k8s_image')
print("k8s: launching image {}".format(image))
pod = KubernetesPodBuilder(
image= image,
cmds=command_list,
kub_req_factory=SimplePodRequestFactory(),
namespace=self.namespace)
pod.add_name(pod_id)
pod.launch()
self._task_counter += 1
self.logger.info("k8s: Job created!")
def delete_job(self, key):
job_id = self.current_jobs[key]
self.helper.delete_job(job_id, namespace=self.namespace)
def sync(self):
"""
The sync function checks the status of all currently running kubernetes jobs.
If a job is completed, its status is placed in the result queue to
be sent back to the scheduler.
:return:
"""
while not self.watcher_queue.empty():
self.end_task()
def end_task(self):
job_id, state = self.watcher_queue.get()
if job_id in self.current_jobs:
key = self.current_jobs[job_id]
self.logger.info("finishing job {}".format(key))
if state:
self.result_queue.put((key, state))
self.current_jobs.pop(job_id)
self.running.pop(key)
def _create_job_id_from_key(self, key, epoch_time):
"""
Kubernetes pod names must be unique and match specific conventions
(i.e. no spaces, periods, etc.)
This function creates a unique name using the epoch time and internal counter
:param key:
:param epoch_time:
:return:
"""
keystr = '-'.join([str(x).replace(' ', '-') for x in key[:2]])
job_fields = [keystr, str(self._task_counter), str(epoch_time)]
unformatted_job_id = '-'.join(job_fields)
job_id = unformatted_job_id.replace('_', '-')
return job_id
def _set_host_id(self, key):
(dag_id, task_id, ex_time) = key
session = settings.Session()
item = session.query(TaskInstance) \
.filter_by(dag_id=dag_id, task_id=task_id, execution_date=ex_time).one()
host_id = item.hostname
print("host is {}".format(host_id))
class KubernetesExecutor(BaseExecutor):
def start(self):
self.logger.info('k8s: starting kubernetes executor')
self.task_queue = Queue()
self._session = settings.Session()
self.result_queue = Queue()
self.kub_client = AirflowKubernetesScheduler(self.task_queue,
self.result_queue,
running=self.running)
def sync(self):
self.kub_client.sync()
while not self.result_queue.empty():
results = self.result_queue.get()
self.logger.info("reporting {}".format(results))
self.change_state(*results)
# TODO this could be a job_counter based on max jobs a user wants
if len(self.kub_client.current_jobs) > 3:
self.logger.info("currently a job is running")
else:
self.logger.info("queue ready, running next")
if not self.task_queue.empty():
(key, command) = self.task_queue.get()
self.kub_client.run_next((key, command))
def terminate(self):
pass
def change_state(self, key, state):
self.logger.info("k8s: setting state of {} to {}".format(key, state))
if state != State.RUNNING:
self.kub_client.delete_job(key)
self.running.pop(key)
self.event_buffer[key] = state
(dag_id, task_id, ex_time) = key
item = self._session.query(TaskInstance).filter_by(
dag_id=dag_id,
task_id=task_id,
execution_date=ex_time).one()
if item.state == State.RUNNING or item.state == State.QUEUED:
item.state = state
self._session.add(item)
self._session.commit()
def end(self):
self.logger.info('ending kube executor')
self.task_queue.join()
def execute_async(self, key, command, queue=None):
self.logger.info("k8s: adding task {} with command {}".format(key, command))
self.task_queue.put((key, command))
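
For reference, a standalone sketch of the pod-naming convention implemented by _create_job_id_from_key above; the key tuple and counter below are illustrative, assuming a (dag_id, task_id, execution_date) key as handed to the executor.

import calendar
import time

def example_pod_id(key, task_counter):
    # join dag_id and task_id, replacing characters Kubernetes rejects in pod names
    keystr = '-'.join([str(x).replace(' ', '-') for x in key[:2]])
    epoch_time = calendar.timegm(time.gmtime())
    job_fields = [keystr, str(task_counter), str(epoch_time)]
    return '-'.join(job_fields).replace('_', '-')

# example_pod_id(('example_dag', 'print_date', '2017-06-27 00:00:00'), 0)
# -> 'example-dag-print-date-0-1498582519' (the epoch suffix varies)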

View file

@@ -1,16 +1,17 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
# -*- coding: utf-8 -*-
#
# http://www.apache.org/licenses/LICENSE-2.0
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from airflow import dag_importer
dag_importer.import_dags()

View file

@@ -0,0 +1,35 @@
# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import yaml
from kubernetes import client, config
class KubernetesHelper(object):
def __init__(self):
config.load_incluster_config()
self.job_api = client.BatchV1Api()
self.pod_api = client.CoreV1Api()
def launch_job(self, pod_info, namespace):
dep = yaml.load(pod_info)
resp = self.job_api.create_namespaced_job(body=dep, namespace=namespace)
return resp
def get_status(self, pod_id, namespace):
return self.job_api.read_namespaced_job(pod_id, namespace).status
def delete_job(self, job_id, namespace):
body = client.V1DeleteOptions()
self.job_api.delete_namespaced_job(name=job_id, namespace=namespace, body=body)
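
A minimal usage sketch of the helper above, assuming it runs inside a cluster (load_incluster_config() requires that); the job name and namespace are illustrative.

helper = KubernetesHelper()
status = helper.get_status('example-job', namespace='default')  # a V1JobStatus
print(status.active, status.succeeded, status.failed)
helper.delete_job('example-job', namespace='default')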

View file

@@ -0,0 +1,74 @@
# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
from kubernetes import client, config
import json
import logging
class KubernetesJobBuilder:
def __init__(
self,
image,
cmds,
namespace,
kub_req_factory=None
):
self.image = image
self.cmds = cmds
self.kub_req_factory = kub_req_factory
self.namespace = namespace
self.logger = logging.getLogger(self.__class__.__name__)
self.envs = {}
self.labels = {}
self.secrets = {}
self.node_selectors = []
self.name = None
def add_env_variables(self, env):
self.envs = env
def add_secrets(self, secrets):
self.secrets = secrets
def add_labels(self, labels):
self.labels = labels
def add_name(self, name):
self.name = name
def set_namespace(self, namespace):
self.namespace = namespace
def launch(self):
"""
Launches the pod synchronously and waits for completion.
"""
k8s_beta = self._kube_client()
req = self.kub_req_factory.create(self)
print(json.dumps(req))
resp = k8s_beta.create_namespaced_job(body=req, namespace=self.namespace)
self.logger.info("Job created. status='%s', yaml:\n%s",
str(resp.status), str(req))
def _kube_client(self):
config.load_incluster_config()
return client.BatchV1Api()
def _execution_finished(self):
k8s_beta = self._kube_client()
resp = k8s_beta.read_namespaced_job_status(self.name, namespace=self.namespace)
self.logger.info('status : ' + str(resp.status))
if resp.status.phase == 'Failed':
raise Exception("Job " + self.name + " failed!")
return resp.status.phase != 'Running'
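
A hedged usage sketch of the builder above; the image, command, and job name are illustrative, the import path follows this commit's layout, and launch() only works with in-cluster configuration.

from airflow.contrib.kubernetes.kubernetes_request_factory import SimpleJobRequestFactory

builder = KubernetesJobBuilder(
    image='airflow-slave:latest',
    cmds=['/usr/local/airflow/entrypoint.sh', '/bin/bash', 'sleep', '25'],
    namespace='default',
    kub_req_factory=SimpleJobRequestFactory())
builder.add_name('example-job-0-1498582519')
builder.launch()  # creates a batch/v1 Job via BatchV1Api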

View file

@@ -0,0 +1,74 @@
# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
from kubernetes import client, config
import json
import logging
class KubernetesPodBuilder:
def __init__(
self,
image,
cmds,
namespace,
kub_req_factory=None
):
self.image = image
self.cmds = cmds
self.kub_req_factory = kub_req_factory
self.namespace = namespace
self.logger = logging.getLogger(self.__class__.__name__)
self.envs = {}
self.labels = {}
self.secrets = {}
self.node_selectors = []
self.name = None
def add_env_variables(self, env):
self.envs = env
def add_secrets(self, secrets):
self.secrets = secrets
def add_labels(self, labels):
self.labels = labels
def add_name(self, name):
self.name = name
def set_namespace(self, namespace):
self.namespace = namespace
def launch(self):
"""
Launches the pod synchronously and waits for completion.
"""
k8s_beta = self._kube_client()
req = self.kub_req_factory.create(self)
print(json.dumps(req))
resp = k8s_beta.create_namespaced_pod(body=req, namespace=self.namespace)
self.logger.info("Job created. status='%s', yaml:\n%s",
str(resp.status), str(req))
def _kube_client(self):
config.load_incluster_config()
return client.CoreV1Api()
def _execution_finished(self):
k8s_beta = self._kube_client()
resp = k8s_beta.read_namespaced_job_status(self.name, namespace=self.namespace)
self.logger.info('status : ' + str(resp.status))
if resp.status.phase == 'Failed':
raise Exception("Job " + self.name + " failed!")
return resp.status.phase != 'Running'

View file

@@ -1,16 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
# -*- coding: utf-8 -*-
#
# http://www.apache.org/licenses/LICENSE-2.0
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
from .kubernetes_request_factory import *
from .job_request_factory import *
from .pod_request_factory import *

View file

@@ -0,0 +1,60 @@
# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
import yaml
from .kubernetes_request_factory import *
class SimpleJobRequestFactory(KubernetesRequestFactory):
"""
Request generator for a simple job.
"""
def __init__(self):
pass
_yaml = """apiVersion: batch/v1
kind: Job
metadata:
name: name
spec:
template:
metadata:
name: name
spec:
containers:
- name: base
image: airflow-slave:latest
command: ["/usr/local/airflow/entrypoint.sh", "/bin/bash sleep 25"]
volumeMounts:
- name: shared-data
mountPath: "/usr/local/airflow/dags"
restartPolicy: Never
"""
def create(self, pod):
req = yaml.load(self._yaml)
sub_req = req['spec']['template']
extract_name(pod, sub_req)
extract_labels(pod, sub_req)
extract_image(pod, sub_req)
extract_cmds(pod, sub_req)
if len(pod.node_selectors) > 0:
extract_node_selector(pod, sub_req)
extract_secrets(pod, sub_req)
print("attaching volume mounts")
attach_volume_mounts(sub_req)
return req

View file

@@ -1,165 +1,107 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
# -*- coding: utf-8 -*-
#
# http://www.apache.org/licenses/LICENSE-2.0
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from abc import ABCMeta, abstractmethod
import six
from airflow import dag_importer
class KubernetesRequestFactory:
class KubernetesRequestFactory():
"""
Create requests to be sent to kube API.
Extend this class to talk to kubernetes and generate your specific resources.
This is equivalent of generating yaml files that can be used by `kubectl`
Create requests to be sent to kube API. Extend this class
to talk to kubernetes and generate your specific resources.
This is equivalent of generating yaml files that can be used
by `kubectl`
"""
__metaclass__ = ABCMeta
@abstractmethod
def create(self, pod):
"""
Creates the request for kubernetes API.
Creates the request for kubernetes API.
:param pod: The pod object
:param pod: The pod object
"""
pass
@staticmethod
def extract_image(pod, req):
req['spec']['containers'][0]['image'] = pod.image
@staticmethod
def extract_image_pull_policy(pod, req):
if pod.image_pull_policy:
req['spec']['containers'][0]['imagePullPolicy'] = pod.image_pull_policy
def extract_image(pod, req):
req['spec']['containers'][0]['image'] = pod.image
@staticmethod
def add_secret_to_env(env, secret):
env.append({
'name': secret.deploy_target,
'valueFrom': {
'secretKeyRef': {
'name': secret.secret,
'key': secret.key
}
def add_secret_to_env(env, secret):
env.append({
'name': secret.deploy_target,
'valueFrom': {
'secretKeyRef': {
'name': secret.secret,
'key': secret.key
}
}
})
def extract_labels(pod, req):
for k in pod.labels.keys():
req['metadata']['labels'][k] = pod.labels[k]
def extract_cmds(pod, req):
req['spec']['containers'][0]['command'] = pod.cmds
def extract_node_selector(pod, req):
req['spec']['nodeSelector'] = pod.node_selectors
def extract_secrets(pod, req):
env_secrets = [s for s in pod.secrets if s.deploy_type == 'env']
if len(pod.envs) > 0 or len(env_secrets) > 0:
env = []
for k in pod.envs.keys():
env.append({'name': k, 'value': pod.envs[k]})
for secret in env_secrets:
add_secret_to_env(env, secret)
req['spec']['containers'][0]['env'] = env
def attach_volume_mounts(req):
logging.info("preparing to import dags")
dag_importer.import_dags()
logging.info("using file mount {}".format(dag_importer.dag_import_spec))
req['spec']['volumes'] = [dag_importer.dag_import_spec]
def extract_name(pod, req):
req['metadata']['name'] = pod.name
def extract_volume_secrets(pod, req):
vol_secrets = [s for s in pod.secrets if s.deploy_type == 'volume']
if any(vol_secrets):
req['spec']['containers'][0]['volumeMounts'] = []
req['spec']['volumes'] = []
for idx, vol in enumerate(vol_secrets):
vol_id = 'secretvol' + str(idx)
req['spec']['containers'][0]['volumeMounts'].append({
'mountPath': vol.deploy_target,
'name': vol_id,
'readOnly': True
})
req['spec']['volumes'].append({
'name': vol_id,
'secret': {
'secretName': vol.secret
}
})
@staticmethod
def extract_labels(pod, req):
req['metadata']['labels'] = req['metadata'].get('labels', {})
for k, v in six.iteritems(pod.labels):
req['metadata']['labels'][k] = v
@staticmethod
def extract_cmds(pod, req):
req['spec']['containers'][0]['command'] = pod.cmds
@staticmethod
def extract_args(pod, req):
req['spec']['containers'][0]['args'] = pod.args
@staticmethod
def extract_node_selector(pod, req):
if len(pod.node_selectors) > 0:
req['spec']['nodeSelector'] = pod.node_selectors
@staticmethod
def attach_volumes(pod, req):
req['spec']['volumes'] = pod.volumes
@staticmethod
def attach_volume_mounts(pod, req):
if len(pod.volume_mounts) > 0:
req['spec']['containers'][0]['volumeMounts'] = (
req['spec']['containers'][0].get('volumeMounts', []))
req['spec']['containers'][0]['volumeMounts'].extend(pod.volume_mounts)
@staticmethod
def extract_name(pod, req):
req['metadata']['name'] = pod.name
@staticmethod
def extract_volume_secrets(pod, req):
vol_secrets = [s for s in pod.secrets if s.deploy_type == 'volume']
if any(vol_secrets):
req['spec']['containers'][0]['volumeMounts'] = []
req['spec']['volumes'] = []
for idx, vol in enumerate(vol_secrets):
vol_id = 'secretvol' + str(idx)
req['spec']['containers'][0]['volumeMounts'].append({
'mountPath': vol.deploy_target,
'name': vol_id,
'readOnly': True
})
req['spec']['volumes'].append({
'name': vol_id,
'secret': {
'secretName': vol.secret
}
})
@staticmethod
def extract_env_and_secrets(pod, req):
env_secrets = [s for s in pod.secrets if s.deploy_type == 'env']
if len(pod.envs) > 0 or len(env_secrets) > 0:
env = []
for k in pod.envs.keys():
env.append({'name': k, 'value': pod.envs[k]})
for secret in env_secrets:
KubernetesRequestFactory.add_secret_to_env(env, secret)
req['spec']['containers'][0]['env'] = env
@staticmethod
def extract_resources(pod, req):
if not pod.resources or pod.resources.is_empty_resource_request():
return
req['spec']['containers'][0]['resources'] = {}
if pod.resources.has_requests():
req['spec']['containers'][0]['resources']['requests'] = {}
if pod.resources.request_memory:
req['spec']['containers'][0]['resources']['requests'][
'memory'] = pod.resources.request_memory
if pod.resources.request_cpu:
req['spec']['containers'][0]['resources']['requests'][
'cpu'] = pod.resources.request_cpu
if pod.resources.has_limits():
req['spec']['containers'][0]['resources']['limits'] = {}
if pod.resources.request_memory:
req['spec']['containers'][0]['resources']['limits'][
'memory'] = pod.resources.limit_memory
if pod.resources.request_cpu:
req['spec']['containers'][0]['resources']['limits'][
'cpu'] = pod.resources.limit_cpu
@staticmethod
def extract_init_containers(pod, req):
if pod.init_containers:
req['spec']['initContainers'] = pod.init_containers
@staticmethod
def extract_service_account_name(pod, req):
if pod.service_account_name:
req['spec']['serviceAccountName'] = pod.service_account_name
@staticmethod
def extract_image_pull_secrets(pod, req):
if pod.image_pull_secrets:
req['spec']['imagePullSecrets'] = [{
'name': pull_secret
} for pull_secret in pod.image_pull_secrets.split(',')]
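
To make the mutation pattern concrete, a small sketch of the static helpers above being applied to a bare request dict; the FakePod stand-in is hypothetical and only carries the attributes the helpers read.

import yaml

class FakePod(object):
    name = 'example-pod'
    image = 'airflow-slave:latest'
    cmds = ['/bin/bash', '-c', 'sleep 25']
    labels = {'run_id': 'example-run'}

req = yaml.load("""
apiVersion: v1
kind: Pod
metadata:
  name: name
spec:
  containers:
    - name: base
      image: placeholder
""")
KubernetesRequestFactory.extract_name(FakePod, req)
KubernetesRequestFactory.extract_image(FakePod, req)
KubernetesRequestFactory.extract_cmds(FakePod, req)
KubernetesRequestFactory.extract_labels(FakePod, req)
# req now carries the name, image, command and labels, ready to be sent to the kube API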

View file

@@ -1,28 +1,24 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
# -*- coding: utf-8 -*-
#
# http://www.apache.org/licenses/LICENSE-2.0
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
import kubernetes_request_factory as kreq
import yaml
from airflow.contrib.kubernetes.kubernetes_request_factory.kubernetes_request_factory \
import KubernetesRequestFactory
from airflow import AirflowException
class SimplePodRequestFactory(KubernetesRequestFactory):
class SimplePodRequestFactory(kreq.KubernetesRequestFactory):
"""
Request generator for a simple pod.
Request generator for a simple pod.
"""
_yaml = """apiVersion: v1
kind: Pod
@@ -33,6 +29,9 @@ spec:
- name: base
image: airflow-slave:latest
command: ["/usr/local/airflow/entrypoint.sh", "/bin/bash sleep 25"]
volumeMounts:
- name: shared-data
mountPath: "/usr/local/airflow/dags"
restartPolicy: Never
"""
@@ -40,21 +39,48 @@ spec:
pass
def create(self, pod):
# type: (Pod) -> dict
req = yaml.load(self._yaml)
self.extract_name(pod, req)
self.extract_labels(pod, req)
self.extract_image(pod, req)
self.extract_image_pull_policy(pod, req)
self.extract_cmds(pod, req)
self.extract_args(pod, req)
self.extract_node_selector(pod, req)
self.extract_env_and_secrets(pod, req)
self.extract_volume_secrets(pod, req)
self.attach_volumes(pod, req)
self.attach_volume_mounts(pod, req)
self.extract_resources(pod, req)
self.extract_service_account_name(pod, req)
self.extract_init_containers(pod, req)
self.extract_image_pull_secrets(pod, req)
kreq.extract_name(pod, req)
kreq.extract_labels(pod, req)
kreq.extract_image(pod, req)
kreq.extract_cmds(pod, req)
if len(pod.node_selectors) > 0:
kreq.extract_node_selector(pod, req)
kreq.extract_secrets(pod, req)
kreq.extract_volume_secrets(pod, req)
kreq.attach_volume_mounts(req)
return req
class ReturnValuePodRequestFactory(SimplePodRequestFactory):
"""
Pod request factory with a PreStop hook to upload return value
to the system's etcd service.
:param kube_com_service_factory: Kubernetes Communication Service factory
:type kube_com_service_factory: () => KubernetesCommunicationService
"""
def __init__(self, kube_com_service_factory, result_data_file):
super(ReturnValuePodRequestFactory, self).__init__()
self._kube_com_service_factory = kube_com_service_factory
self._result_data_file = result_data_file
def after_create(self, body, pod):
"""
Augment the pod with hyperparameter-specific logic.
Adds a Kubernetes PreStop hook to upload the model training
metrics to the Kubernetes communication engine (typically
an etcd service running alongside airflow).
"""
container = body['spec']['containers'][0]
pre_stop_hook = self._kube_com_service_factory() \
.pod_pre_stop_hook(self._result_data_file, pod.name)
# Pre-stop hooks only run on containers that are deleted. If the container
# exits naturally there is no pre-stop hook execution. Therefore we
# simulate the hook by wrapping the exec command inside a script.
if "'" in ' '.join(container['command']):
raise AirflowException('Please do not include single quote '
'in your command for hyperparameterized pods')
cmd = ' '.join(["'" + c + "'" if " " in c else c for c in container['command']])
container['command'] = ['/bin/bash', '-c', "({}) ; ({})"
.format(cmd, pre_stop_hook)]
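
A standalone sketch of the command wrapping described in the comment above; the original command and pre-stop hook string are illustrative. Tokens containing spaces are re-quoted so the original command survives inside the bash -c wrapper.

command = ['python', 'train.py', '--epochs 10']
pre_stop_hook = ('echo value=$(cat /tmp/result.json) | '
                 'curl -d "@-" -X PUT etcd-host:2379/v2/keys/pod_metrics/example_task')
cmd = ' '.join(["'" + c + "'" if " " in c else c for c in command])
wrapped = ['/bin/bash', '-c', "({}) ; ({})".format(cmd, pre_stop_hook)]
# wrapped[2] starts with "(python train.py '--epochs 10') ; (echo value=..."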

View file

@@ -1,92 +1,91 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
# -*- coding: utf-8 -*-
#
# http://www.apache.org/licenses/LICENSE-2.0
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
class Resources:
def __init__(
self,
request_memory=None,
request_cpu=None,
limit_memory=None,
limit_cpu=None):
self.request_memory = request_memory
self.request_cpu = request_cpu
self.limit_memory = limit_memory
self.limit_cpu = limit_cpu
def is_empty_resource_request(self):
return not self.has_limits() and not self.has_requests()
def has_limits(self):
return self.limit_cpu is not None or self.limit_memory is not None
def has_requests(self):
return self.request_cpu is not None or self.request_memory is not None
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from kubernetes import client, config
from kubernetes_request_factory import KubernetesRequestFactory, SimplePodRequestFactory
import logging
from airflow import AirflowException
import time
import json
class Pod:
"""
Represents a kubernetes pod and manages execution of a single pod.
:param image: The docker image
:type image: str
:param env: A dict containing the environment variables
:type env: dict
:param cmds: The command to be run on the pod
:type cmd: list str
:param secrets: Secrets to be launched to the pod
:type secrets: list Secret
:param result: The result that will be returned to the operator after
successful execution of the pod
:type result: any
Represents a kubernetes pod and manages execution of a single pod.
:param image: The docker image
:type image: str
:param env: A dict containing the environment variables
:type env: dict
:param cmds: The command to be run on the pod
:type cmd: list str
:param secrets: Secrets to be launched to the pod
:type secrets: list Secret
:param result: The result that will be returned to the operator after
successful execution of the pod
:type result: any
"""
pod_timeout = 3600
def __init__(
self,
image,
envs,
cmds,
args=None,
secrets=None,
labels=None,
node_selectors=None,
name=None,
volumes=None,
volume_mounts=None,
secrets,
labels,
node_selectors,
kube_req_factory,
name,
namespace='default',
result=None,
image_pull_policy="IfNotPresent",
image_pull_secrets=None,
init_containers=None,
service_account_name=None,
resources=None
):
result=None):
self.image = image
self.envs = envs or {}
self.envs = envs
self.cmds = cmds
self.args = args or []
self.secrets = secrets or []
self.secrets = secrets
self.result = result
self.labels = labels or {}
self.labels = labels
self.name = name
self.volumes = volumes or []
self.volume_mounts = volume_mounts or []
self.node_selectors = node_selectors or []
self.node_selectors = node_selectors
self.kube_req_factory = (kube_req_factory or SimplePodRequestFactory)()
self.namespace = namespace
self.image_pull_policy = image_pull_policy
self.image_pull_secrets = image_pull_secrets
self.init_containers = init_containers
self.service_account_name = service_account_name
self.resources = resources or Resources()
self.logger = logging.getLogger(self.__class__.__name__)
if not isinstance(self.kube_req_factory, KubernetesRequestFactory):
raise AirflowException('`kube_req_factory`'
' should implement KubernetesRequestFactory')
def launch(self):
"""
Launches the pod synchronously and waits for completion.
"""
k8s_beta = self._kube_client()
req = self.kube_req_factory.create(self)
print(json.dumps(req))
resp = k8s_beta.create_namespaced_job(body=req, namespace=self.namespace)
self.logger.info("Job created. status='%s', yaml:\n%s"
% (str(resp.status), str(req)))
while not self._execution_finished():
time.sleep(10)
return self.result
def _kube_client(self):
config.load_incluster_config()
return client.BatchV1Api()
def _execution_finished(self):
k8s_beta = self._kube_client()
resp = k8s_beta.read_namespaced_job_status(self.name, namespace=self.namespace)
self.logger.info('status : ' + str(resp.status))
if resp.status.phase == 'Failed':
raise Exception("Job " + self.name + " failed!")
return resp.status.phase != 'Running'

View file

@@ -1,135 +1,145 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
# -*- coding: utf-8 -*-
#
# http://www.apache.org/licenses/LICENSE-2.0
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import json
import logging
import time
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.state import State
from datetime import datetime as dt
from airflow.contrib.kubernetes.kubernetes_request_factory import \
pod_request_factory as pod_fac
from kubernetes import watch
from kubernetes.client.rest import ApiException
from airflow import AirflowException
from requests.exceptions import HTTPError
from .kube_client import get_kube_client
import urllib2
from kubernetes import client, config
from kubernetes_request_factory import KubernetesRequestFactory
from pod import Pod
class PodStatus(object):
PENDING = 'pending'
RUNNING = 'running'
FAILED = 'failed'
SUCCEEDED = 'succeeded'
def kube_client():
config.load_incluster_config()
return client.CoreV1Api()
class PodLauncher(LoggingMixin):
def __init__(self, kube_client=None):
super(PodLauncher, self).__init__()
self._client = kube_client or get_kube_client()
self._watch = watch.Watch()
self.kube_req_factory = pod_fac.SimplePodRequestFactory()
def incluster_namespace():
"""
:return: The incluster namespace.
"""
config.load_incluster_config()
k8s_configuration = config.incluster_config.configuration
encoded_namespace = k8s_configuration.api_key['authorization'].split(' ')[-1]
api_key = str(base64.b64decode(encoded_namespace))
key_with_namespace = [k for k in api_key.split(',') if 'namespace' in k][0]
unformatted_namespace = key_with_namespace.split(':')[-1]
return unformatted_namespace.replace('"', '')
def run_pod_async(self, pod):
req = self.kube_req_factory.create(pod)
self.log.debug('Pod Creation Request: \n{}'.format(json.dumps(req, indent=2)))
class KubernetesLauncher:
"""
This class is responsible for launching objects to Kubernetes.
Extend this class to launch exotic objects.
Before trying to extend this class, check if augmenting the request factory
is enough for your use-case.
:param kube_object: A pod or anything that represents a Kubernetes object
:type kube_object: Pod
:param request_factory: A factory method to create kubernetes requests.
"""
pod_timeout = 3600
def __init__(self, kube_object, request_factory):
if not isinstance(kube_object, Pod):
raise Exception('`kube_object` must inherit from Pod')
if not isinstance(request_factory, KubernetesRequestFactory):
raise Exception('`request_factory` must inherit from '
'KubernetesRequestFactory')
self.pod = kube_object
self.request_factory = request_factory
def launch(self):
"""
Launches the pod synchronously and waits for completion.
No return value from execution. Will raise an exception if things failed
"""
k8s_beta = kube_client()
req = self.request_factory.create(self)
logging.info(json.dumps(req))
resp = k8s_beta.create_namespaced_pod(body=req, namespace=self.pod.namespace)
logging.info("Job created. status='%s', yaml:\n%s"
% (str(resp.status), str(req)))
for i in range(1, self.pod_timeout):
time.sleep(10)
logging.info('Waiting for success')
if self._execution_finished():
logging.info('Job finished!')
return
raise Exception("Job timed out!")
def _execution_finished(self):
k8s_beta = kube_client()
resp = k8s_beta.read_namespaced_pod_status(
self.pod.name,
namespace=self.pod.namespace)
logging.info('status : ' + str(resp.status))
logging.info('phase : ' + str(resp.status.phase))
if resp.status.phase == 'Failed':
raise Exception("Job " + self.pod.name + " failed!")
return resp.status.phase != 'Running'
class KubernetesCommunicationService:
"""
A service that manages communication between pods in Kubernetes and the airflow dag run.
Note that the etcd service runs side by side with airflow on the same machine
(via Kubernetes), so on the airflow side we use localhost, and on the remote
side we use the provided etcd host.
"""
def __init__(self, etcd_host, etcd_port):
self.etcd_host = etcd_host
self.etcd_port = etcd_port
self.url = 'http://localhost:{}'.format(self.etcd_port)
def pod_pre_stop_hook(self, return_data_file, task_id):
return 'echo value=$(cat %s) | curl -d "@-" -X PUT %s:%s/v2/keys/pod_metrics/%s' \
% (
return_data_file, self.etcd_host, self.etcd_port, task_id)
def pod_return_data(self, task_id):
"""
Returns the pod's return data. The pod_pre_stop_hook is responsible for uploading
the return data to etcd.
If the return_data_file is generated by the application, the pre-stop hook
will upload it to etcd and we will download it back into airflow.
"""
logging.info('querying {} for task id {}'.format(self.url, task_id))
try:
resp = self._client.create_namespaced_pod(body=req, namespace=pod.namespace)
self.log.debug('Pod Creation Response: {}'.format(resp))
except ApiException:
self.log.exception('Exception when attempting to create Namespaced Pod.')
result = urllib2.urlopen(self.url + '/v2/keys/pod_metrics/' + task_id).read()
logging.info('result for querying {} for task id {}: {}'
.format(self.url, task_id, result))
result = json.loads(result)['node']['value']
return result
except urllib2.HTTPError as err:
if err.code == 404:
return None # Data not found
raise
return resp
def run_pod(self, pod, startup_timeout=120, get_logs=True):
# type: (Pod) -> State
"""
Launches the pod synchronously and waits for completion.
Args:
pod (Pod):
startup_timeout (int): Timeout for startup of the pod (if the pod is pending
for too long, the task is considered a failure)
"""
resp = self.run_pod_async(pod)
curr_time = dt.now()
if resp.status.start_time is None:
while self.pod_not_started(pod):
delta = dt.now() - curr_time
if delta.seconds >= startup_timeout:
raise AirflowException("Pod took too long to start")
time.sleep(1)
self.log.debug('Pod not yet started')
final_status = self._monitor_pod(pod, get_logs)
return final_status
def _monitor_pod(self, pod, get_logs):
# type: (Pod) -> State
if get_logs:
logs = self._client.read_namespaced_pod_log(
name=pod.name,
namespace=pod.namespace,
follow=True,
tail_lines=10,
_preload_content=False)
for line in logs:
self.log.info(line)
else:
while self.pod_is_running(pod):
self.log.info("Pod {} has state {}".format(pod.name, State.RUNNING))
time.sleep(2)
return self._task_status(self.read_pod(pod))
def _task_status(self, event):
# type: (V1Pod) -> State
self.log.info(
"Event: {} had an event of type {}".format(event.metadata.name,
event.status.phase))
status = self.process_status(event.metadata.name, event.status.phase)
return status
def pod_not_started(self, pod):
state = self._task_status(self.read_pod(pod))
return state == State.QUEUED
def pod_is_running(self, pod):
state = self._task_status(self.read_pod(pod))
return state != State.SUCCESS and state != State.FAILED
def read_pod(self, pod):
try:
return self._client.read_namespaced_pod(pod.name, pod.namespace)
except HTTPError as e:
raise AirflowException("There was an error reading the kubernetes API: {}"
.format(e))
def process_status(self, job_id, status):
status = status.lower()
if status == PodStatus.PENDING:
return State.QUEUED
elif status == PodStatus.FAILED:
self.log.info("Event: {} Failed".format(job_id))
return State.FAILED
elif status == PodStatus.SUCCEEDED:
self.log.info("Event: {} Succeeded".format(job_id))
return State.SUCCESS
elif status == PodStatus.RUNNING:
return State.RUNNING
else:
self.log.info("Event: Invalid state {} on job {}".format(status, job_id))
return State.FAILED
@staticmethod
def from_dag_default_args(dag):
(etcd_host, etcd_port) = dag.default_args.get('etcd_endpoint', ':').split(':')
logging.info('Setting etcd endpoint from dag default args {}:{}'
.format(etcd_host, etcd_port))
if not etcd_host:
raise Exception('`KubernetesCommunicationService` '
'requires an etcd endpoint. Please define it in dag '
'default_args')
return KubernetesCommunicationService(etcd_host, etcd_port)
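
A usage sketch of the communication service above; the etcd host, port, data file, and task id are illustrative, and pod_return_data assumes the local etcd endpoint described in the docstring is reachable.

kube_com = KubernetesCommunicationService('etcd.example.com', '2379')
hook = kube_com.pod_pre_stop_hook('/tmp/result.json', 'example_task')
# hook == 'echo value=$(cat /tmp/result.json) | curl -d "@-" -X PUT '
#         'etcd.example.com:2379/v2/keys/pod_metrics/example_task'
result = kube_com.pod_return_data('example_task')  # None if no data was uploaded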

View file

@@ -0,0 +1,13 @@
# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
from .k8s_pod_operator import *

View file

@@ -0,0 +1,126 @@
# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from airflow.exceptions import AirflowException
from airflow.operators.python_operator import PythonOperator
from airflow.utils.decorators import apply_defaults
from airflow.contrib.kubernetes.pod_launcher import KubernetesLauncher, \
KubernetesCommunicationService
from airflow.contrib.kubernetes.kubernetes_request_factory import \
SimplePodRequestFactory, \
ReturnValuePodRequestFactory
from .op_context import OpContext
class PodOperator(PythonOperator):
"""
Executes a pod and waits for the job to finish.
:param dag_run_id: The unique run ID that would be attached to the pod as a label
:type dag_run_id: str
:param pod_factory: Reference to the function that creates the pod with format:
function (OpContext) => Pod
:type pod_factory: callable
:param cache_output: If set to true, the output of the pod would be saved in a
cache object using md5 hash of all the pod parameters
and in case of success, the cached results will be returned
on consecutive calls. Only use this
"""
# template_fields = tuple('dag_run_id')
ui_color = '#8da7be'
@apply_defaults
def __init__(
self,
dag_run_id,
pod_factory,
cache_output,
kube_request_factory=None,
*args,
**kwargs
):
super(PodOperator, self).__init__(
python_callable=lambda _: 1,
provide_context=True,
*args,
**kwargs)
self.logger = logging.getLogger(self.__class__.__name__)
if not callable(pod_factory):
raise AirflowException('`pod_factory` param must be callable')
self.dag_run_id = dag_run_id
self.pod_factory = pod_factory
self._cache_output = cache_output
self.op_context = OpContext(self.task_id)
self.kwargs = kwargs
self._kube_request_factory = kube_request_factory or SimplePodRequestFactory
def execute(self, context):
task_instance = context.get('task_instance')
if task_instance is None:
raise AirflowException('`task_instance` is empty! This should not happen')
self.op_context.set_xcom_instance(task_instance)
pod = self.pod_factory(self.op_context, context)
# Customize the pod
pod.name = self.task_id
pod.labels['run_id'] = self.dag_run_id
pod.namespace = self.dag.default_args.get('namespace', pod.namespace)
# Launch the pod and wait for it to finish
KubernetesLauncher(pod, self._kube_request_factory).launch()
self.op_context.result = pod.result
# Cache the output
custom_return_value = self.on_pod_success(context)
if custom_return_value:
self.op_context.custom_return_value = custom_return_value
return self.op_context.result
def on_pod_success(self, context):
"""
Called when pod is executed successfully.
:return: Returns a custom return value for pod which will
be stored in xcom
"""
pass
class ReturnValuePodOperator(PodOperator):
"""
This pod operator is a normal pod operator with the addition of
reading a custom return value back from kubernetes.
"""
def __init__(self,
kube_com_service_factory,
result_data_file,
*args, **kwargs):
super(ReturnValuePodOperator, self).__init__(*args, **kwargs)
if not isinstance(kube_com_service_factory(), KubernetesCommunicationService):
raise AirflowException(
'`kube_com_service_factory` must be of type '
'KubernetesCommunicationService')
self._kube_com_service_factory = kube_com_service_factory
self._result_data_file = result_data_file
self._kube_request_factory = self._return_value_kube_request # Overwrite the
# default request factory
def on_pod_success(self, context):
return_val = self._kube_com_service_factory().pod_return_data(self.task_id)
self.op_context.result = return_val # We also overwrite the results
return return_val
def _return_value_kube_request(self):
return ReturnValuePodRequestFactory(self._kube_com_service_factory,
self._result_data_file)
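
A hedged sketch of wiring the operator above into a DAG; the DAG, image, and factory are illustrative and not part of this commit, and the import paths are assumed from this commit's layout. The pod_factory receives the OpContext and the Airflow context and returns a Pod.

from datetime import datetime
from airflow import DAG
from airflow.contrib.kubernetes.pod import Pod

def example_pod_factory(op_context, context):
    return Pod(
        image='airflow-slave:latest',
        envs={},
        cmds=['/bin/bash', '-c', 'sleep 25'],
        secrets=[],
        labels={},
        node_selectors={},
        kube_req_factory=None,  # falls back to SimplePodRequestFactory
        name='example-pod')

dag = DAG('example_k8s_pod', start_date=datetime(2017, 6, 27), schedule_interval=None)

run_pod = PodOperator(
    task_id='example_pod_task',
    dag_run_id='example-run-id',
    pod_factory=example_pod_factory,
    cache_output=False,
    dag=dag)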

View file

@@ -0,0 +1,104 @@
# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
from airflow import AirflowException
import logging
class OpContext(object):
"""
Data model for operation context of a pod operator with hyper parameters.
OpContext is able to communicate the context between PodOperators by
encapsulating XCom communication
Note: do not directly modify the upstreams
Also note: xcom_instance MUST be set before any attribute of this class can be
read.
:param task_id: The task ID
"""
_supported_attributes = {'hyper_parameters', 'custom_return_value'}
def __init__(self, task_id):
self.task_id = task_id
self._upstream = []
self._result = '__not_set__'
self._data = {}
self._xcom_instance = None
self._parent = None
def __str__(self):
return 'upstream: [' + \
','.join([u.task_id for u in self._upstream]) + ']\n' + \
'params:' + ','.join(
[k + '=' + str(self._data[k]) for k in self._data.keys()])
def __setattr__(self, name, value):
if name in self._data:
raise AirflowException('`{}` is already set'.format(name))
if name not in self._supported_attributes:
logging.warn(
'`{}` is not in the supported attribute list for OpContext'.format(name))
self.get_xcom_instance().xcom_push(key=name, value=value)
self._data[name] = value
def __getattr__(self, item):
if item not in self._supported_attributes:
logging.warn(
'`{}` is not in the supported attribute list for OpContext'.format(item))
if item not in self._data:
self._data[item] = self.get_xcom_instance().xcom_pull(key=item,
task_ids=self.task_id)
return self._data[item]
@property
def result(self):
if self._result == '__not_set__':
self._result = self.get_xcom_instance().xcom_pull(task_ids=self.task_id)
return self._result
@result.setter
def result(self, value):
if self._result != '__not_set__':
raise AirflowException('`result` is already set')
self._result = value
@property
def upstream(self):
return self._upstream
def append_upstream(self, upstream_op_contexts):
"""
Appends a list of op_contexts to the upstream. It will create new instances and
set the task_id.
All the upstream op_contexts will share the same xcom_instance with this
op_context.
:param upstream_op_contexts: List of upstream op_contexts
"""
for up in upstream_op_contexts:
op_context = OpContext(up.task_id)
op_context._parent = self
self._upstream.append(op_context)
def set_xcom_instance(self, xcom_instance):
"""
Sets the xcom_instance for this op_context and upstreams
:param xcom_instance: The Airflow TaskInstance for communication through XCom
:type xcom_instance: airflow.models.TaskInstance
"""
self._xcom_instance = xcom_instance
def get_xcom_instance(self):
if self._xcom_instance is None and self._parent is None:
raise AirflowException(
'Trying to access attributes from OpContext before setting the '
'xcom_instance')
return self._xcom_instance or self._parent.get_xcom_instance()

View file

@@ -0,0 +1,83 @@
# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from airflow import configuration
def _integrate_plugins():
pass
dag_import_spec = {}
def import_dags():
logging.info("importing dags")
if configuration.has_option('core', 'k8s_mode'):
mode = configuration.get('core', 'k8s_mode')
dag_import_func(mode)()
else:
_import_hostpath()
def dag_import_func(mode):
return {
'git': _import_git,
'cinder': _import_cinder,
}.get(mode, _import_hostpath)
def _import_hostpath():
logging.info("importing dags locally")
spec = {'name': 'shared-data', 'hostPath': {}}
spec['hostPath']['path'] = '/tmp/dags'
global dag_import_spec
dag_import_spec = spec
def _import_cinder():
'''
kind: StorageClass
apiVersion: storage.k8s.io/v1
metadata:
name: gold
provisioner: kubernetes.io/cinder
parameters:
type: fast
availability: nova
:return:
'''
global dag_import_spec
spec = {}
spec['kind'] = 'StorageClass'
spec['apiVersion'] = 'storage.k8s.io/v1'
spec['metadata'] = {'name': 'gold'}
spec['provisioner'] = 'kubernetes.io/cinder'
spec['parameters'] = {'type': 'fast', 'availability': 'nova'}
dag_import_spec = spec
def _import_git():
logging.info("importing dags from github")
global dag_import_spec
git_link = configuration.get('core', 'k8s_git_link')
spec = {'name': 'shared-data', 'gitRepo': {}}
spec['gitRepo']['repository'] = git_link
if configuration.has_option('core','k8s_git_revision'):
revision = configuration.get('core', 'k8s_git_revision')
spec['gitRepo']['revision'] = revision
dag_import_spec = spec
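
For reference, a sketch of the volume spec the git importer above produces, assuming hypothetical k8s_git_link and k8s_git_revision values in airflow.cfg.

# [core]
# k8s_mode = git
# k8s_git_link = https://github.com/example/dags.git
# k8s_git_revision = master
#
# resulting dag_import_spec:
example_spec = {
    'name': 'shared-data',
    'gitRepo': {
        'repository': 'https://github.com/example/dags.git',
        'revision': 'master',
    },
}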

View file

@@ -19,13 +19,11 @@
import sys
from airflow import configuration
from airflow.exceptions import AirflowException
from airflow.executors.base_executor import BaseExecutor
from airflow.executors.local_executor import LocalExecutor
from airflow.executors.sequential_executor import SequentialExecutor
from airflow.exceptions import AirflowException
from airflow.utils.log.logging_mixin import LoggingMixin
DEFAULT_EXECUTOR = None
def _integrate_plugins():
@@ -52,6 +50,8 @@ def GetDefaultExecutor():
return DEFAULT_EXECUTOR
def _get_executor(executor_name):
"""
Creates a new instance of the named executor. In case the executor name is not known in airflow,
@@ -70,6 +70,9 @@ def _get_executor(executor_name):
elif executor_name == 'MesosExecutor':
from airflow.contrib.executors.mesos_executor import MesosExecutor
return MesosExecutor()
elif executor_name == 'KubernetesExecutor':
from airflow.contrib.executors.kubernetes_executor import KubernetesExecutor
return KubernetesExecutor()
else:
# Loading plugins
_integrate_plugins()
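
With this change, the new executor is selected like the others, through airflow.cfg; an illustrative snippet (the k8s_image option is read by AirflowKubernetesScheduler.run_next, and the pod namespace comes from the k8s_POD_NAMESPACE environment variable):

[core]
executor = KubernetesExecutor
k8s_image = airflow-slave:latest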

View file

@@ -7,9 +7,9 @@
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -21,7 +21,6 @@ from builtins import range
from airflow import configuration
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.state import State
PARALLELISM = configuration.conf.getint('core', 'PARALLELISM')
@@ -40,6 +39,7 @@ class BaseExecutor(LoggingMixin):
self.queued_tasks = {}
self.running = {}
self.event_buffer = {}
self.logger.setLevel(10)
def start(self): # pragma: no cover
"""
@@ -53,6 +53,8 @@ class BaseExecutor(LoggingMixin):
if key not in self.queued_tasks and key not in self.running:
self.log.info("Adding to queue: %s", command)
self.queued_tasks[key] = (command, priority, queue, task_instance)
else:
self.logger.info("could not queue task {}".format(key))
def queue_task_instance(
self,
@@ -104,8 +106,7 @@
"""
pass
def heartbeat(self):
def heartbeat(self, km=False):
# Triggering new jobs
if not self.parallelism:
open_slots = len(self.queued_tasks)
@@ -131,14 +132,13 @@
# does NOT eliminate it.
self.queued_tasks.pop(key)
ti.refresh_from_db()
if ti.state != State.RUNNING:
if ti.state != State.RUNNING or km:
self.running[key] = command
self.execute_async(key, command=command, queue=queue)
else:
self.log.debug(
'Task is already running, not sending to executor: %s',
key
)
self.logger.info(
'Task is already running, not sending to '
'executor: {}'.format(key))
# Calling child class sync method
self.log.debug("Calling the %s sync method", self.__class__)

View file

@@ -1117,6 +1117,28 @@ class TaskInstance(Base, LoggingMixin):
session.merge(self)
session.commit()
@provide_session
def update_hostname(self, hostname, session=None):
"""
For use in kubernetes mode. Update the session to allow heartbeating to SQL
:param session:
:return:
"""
t_i = TaskInstance
qry = session.query(t_i).filter(
t_i.dag_id == self.dag_id,
t_i.task_id == self.task_id,
t_i.execution_date == self.execution_date)
ti = qry.first()
if ti:
ti.hostname = hostname
session.add(ti)
session.commit()
@provide_session
def refresh_from_db(self, session=None, lock_for_update=False):
"""

View file

@@ -48,6 +48,7 @@ class AirflowPlugin(object):
admin_views = []
flask_blueprints = []
menu_links = []
dag_importer = None
@classmethod
def validate(cls):
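A sketch of how a plugin might populate the new dag_importer field (the importer reference is hypothetical; its exact contract is not visible in this diff):

# Sketch only; GitDagImporter and its module path are made up for illustration.
from airflow.plugins_manager import AirflowPlugin

class KubeDagImportPlugin(AirflowPlugin):
    name = "kube_dag_import_plugin"
    dag_importer = "my_company.importers.GitDagImporter"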

0
kubectl Normal file
View file

Binary data
scripts/ci/kubernetes/docker/airflow.tar.gz Normal file

Binary file not shown.

View file

@@ -0,0 +1,195 @@
# Licensed to the Apache Software Foundation (ASF) under one *
# or more contributor license agreements. See the NOTICE file *
# distributed with this work for additional information *
# regarding copyright ownership. The ASF licenses this file *
# to you under the Apache License, Version 2.0 (the *
# "License"); you may not use this file except in compliance *
# with the License. You may obtain a copy of the License at *
# *
# http://www.apache.org/licenses/LICENSE-2.0 *
# *
# Unless required by applicable law or agreed to in writing, *
# software distributed under the License is distributed on an *
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY *
# KIND, either express or implied. See the License for the *
# specific language governing permissions and limitations *
# under the License. *
# The backing volume can be anything you want; it just needs to support the `ReadWriteOnce` access mode.
# hostPath is used here because it is convenient for testing on minikube, but any (non-local) volume will work on a real cluster.
kind: PersistentVolume
apiVersion: v1
metadata:
name: airflow-dags
labels:
type: local
spec:
capacity:
storage: 10Gi
accessModes:
- ReadWriteOnce
hostPath:
path: "/data/airflow-dags"
---
kind: PersistentVolumeClaim
apiVersion: v1
metadata:
name: airflow-dags
spec:
accessModes:
- ReadWriteOnce
resources:
requests:
storage: 10Gi
---
apiVersion: extensions/v1beta1
kind: Deployment
metadata:
name: airflow
spec:
replicas: 1
template:
metadata:
labels:
name: airflow
spec:
initContainers:
- name: "init"
image: "airflow/ci:latest"
imagePullPolicy: "IfNotPresent"
volumeMounts:
- name: airflow-configmap
mountPath: /root/airflow/airflow.cfg
subPath: airflow.cfg
- name: airflow-dags
mountPath: /root/airflow/dags
env:
- name: SQL_ALCHEMY_CONN
valueFrom:
secretKeyRef:
name: airflow-secrets
key: sql_alchemy_conn
command:
- "bash"
args:
- "-cx"
- "cd /usr/local/lib/python2.7/dist-packages/airflow && cp -R example_dags/* /root/airflow/dags/ && airflow initdb && alembic upgrade heads"
containers:
- name: web
image: airflow/ci:latest
imagePullPolicy: IfNotPresent
ports:
- name: web
containerPort: 8080
args: ["webserver"]
env:
- name: AIRFLOW_KUBE_NAMESPACE
valueFrom:
fieldRef:
fieldPath: metadata.namespace
- name: SQL_ALCHEMY_CONN
valueFrom:
secretKeyRef:
name: airflow-secrets
key: sql_alchemy_conn
volumeMounts:
- name: airflow-configmap
mountPath: /root/airflow/airflow.cfg
subPath: airflow.cfg
- name: airflow-dags
mountPath: /root/airflow/dags
readinessProbe:
initialDelaySeconds: 5
timeoutSeconds: 5
periodSeconds: 5
httpGet:
path: /admin
port: 8080
livenessProbe:
initialDelaySeconds: 5
timeoutSeconds: 5
failureThreshold: 5
httpGet:
path: /admin
port: 8080
- name: scheduler
image: airflow/ci:latest
imagePullPolicy: IfNotPresent
args: ["scheduler"]
env:
- name: AIRFLOW_KUBE_NAMESPACE
valueFrom:
fieldRef:
fieldPath: metadata.namespace
- name: SQL_ALCHEMY_CONN
valueFrom:
secretKeyRef:
name: airflow-secrets
key: sql_alchemy_conn
volumeMounts:
- name: airflow-configmap
mountPath: /root/airflow/airflow.cfg
subPath: airflow.cfg
- name: airflow-dags
mountPath: /root/airflow/dags
volumes:
- name: airflow-dags
persistentVolumeClaim:
claimName: airflow-dags
- name: airflow-configmap
configMap:
name: airflow-configmap
---
apiVersion: v1
kind: Service
metadata:
name: airflow
spec:
type: NodePort
ports:
- port: 8080
nodePort: 30809
selector:
name: airflow
---
apiVersion: v1
kind: Secret
metadata:
name: airflow-secrets
type: Opaque
data:
# The sql_alchemy_conn value is a base64-encoded representation of this connection string:
# postgresql+psycopg2://root:root@postgres-airflow:5432/airflow
sql_alchemy_conn: cG9zdGdyZXNxbCtwc3ljb3BnMjovL3Jvb3Q6cm9vdEBwb3N0Z3Jlcy1haXJmbG93OjU0MzIvYWlyZmxvdwo=
---
apiVersion: v1
kind: ConfigMap
metadata:
name: airflow-configmap
data:
airflow.cfg: |
[core]
airflow_home = /root/airflow
dags_folder = /root/airflow/dags
base_log_folder = /root/airflow/logs
logging_level = INFO
executor = KubernetesExecutor
parallelism = 32
plugins_folder = /root/airflow/plugins
sql_alchemy_conn = $SQL_ALCHEMY_CONN
[scheduler]
dag_dir_list_interval = 60
child_process_log_directory = /root/airflow/logs/scheduler
[kubernetes]
airflow_configmap = airflow-configmap
worker_container_repository = airflow/ci
worker_container_tag = latest
delete_worker_pods = False
git_repo = https://github.com/grantnicholas/testdags.git
git_branch = master
dags_volume_claim = airflow-dags
[kubernetes_secrets]
SQL_ALCHEMY_CONN = airflow-secrets=sql_alchemy_conn
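Incidentally, the base64 value in the Secret above can be reproduced with a short sketch like this; the committed value appears to include a trailing newline (as produced by echo without -n):

# Sketch: regenerate the sql_alchemy_conn secret value shown above.
import base64

conn = "postgresql+psycopg2://root:root@postgres-airflow:5432/airflow\n"
print(base64.b64encode(conn.encode("utf-8")).decode("ascii"))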

View file

@@ -0,0 +1,96 @@
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
alembic
azure-storage>=0.34.0
bcrypt
bleach
boto
boto3
celery
cgroupspy
chartkick
cloudant
coverage
coveralls
croniter>=0.3.17
cryptography
datadog
dill
distributed
docker-py
filechunkio
flake8
flask
flask-admin
flask-bcrypt
flask-cache
flask-login==0.2.11
Flask-WTF
flower
freezegun
future
google-api-python-client>=1.5.0,<1.6.0
gunicorn
hdfs
hive-thrift-py
impyla
ipython
jaydebeapi
jinja2<2.9.0
jira
ldap3
lxml
markdown
mock
moto==1.1.19
mysqlclient
nose
nose-exclude
nose-ignore-docstring==0.2
nose-timer
oauth2client>=2.0.2,<2.1.0
pandas
pandas-gbq
parameterized
paramiko>=2.1.1
pendulum>=1.3.2
psutil>=4.2.0, <5.0.0
psycopg2
pygments
pyhive
pykerberos
PyOpenSSL
PySmbClient
python-daemon
python-dateutil
qds-sdk>=1.9.6
redis
rednose
requests
requests-kerberos
requests_mock
sendgrid
setproctitle
slackclient
sphinx
sphinx-argparse
Sphinx-PyPI-upload
sphinx_rtd_theme
sqlalchemy>=1.1.15, <1.2.0
statsd
thrift
thrift_sasl
unicodecsv
zdesk
kubernetes

View file

@@ -20,4 +20,4 @@
from __future__ import absolute_import
from .operators import *
from .sensors import *
from .utils import *
from .kubernetes import *

View file

@@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

View file

@@ -0,0 +1,12 @@
# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and

View file

@@ -0,0 +1,59 @@
# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
import unittest
from airflow.contrib.kubernetes.kubernetes_job_builder import KubernetesJobBuilder
from airflow.contrib.kubernetes.kubernetes_request_factory import SimpleJobRequestFactory
from airflow import configuration
import json
secrets = {}
labels = {}
base_job = {'kind': 'Job',
'spec': {
'template': {
'spec': {
'restartPolicy': 'Never',
'volumes': [{'hostPath': {'path': '/tmp/dags'}, 'name': 'shared-data'}],
'containers': [
{'command': ['try', 'this', 'first'],
'image': 'foo.image', 'volumeMounts': [
{
'mountPath': '/usr/local/airflow/dags',
'name': 'shared-data'}
],
'name': 'base',
'imagePullPolicy': 'Never'}
]
},
'metadata': {'name': 'name'}
}
},
'apiVersion': 'batch/v1', 'metadata': {'name': None}
}
class KubernetesJobRequestTest(unittest.TestCase):
job_to_load = None
job_req_factory = SimpleJobRequestFactory()
def setUp(self):
configuration.load_test_config()
self.job_to_load = KubernetesJobBuilder(
image='foo.image',
cmds=['try', 'this', 'first']
)
def test_job_creation_with_base_values(self):
base_job_result = self.job_req_factory.create(self.job_to_load)
self.assertEqual(base_job_result, base_job)
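One small addition that would let the test file run directly (a conventional unittest entry point, not present in the diff as shown):

if __name__ == '__main__':
    unittest.main()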