Merge pull request #96 from mistercrunch/async_ti_job
New RunTaskJob that runs async, implements kill signal
This commit is contained in:
Коммит
689853a797
|
@ -1,6 +1,7 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
from airflow.configuration import conf
|
||||
from airflow.executors import DEFAULT_EXECUTOR
|
||||
from airflow import settings
|
||||
from airflow import utils
|
||||
from airflow import jobs
|
||||
|
@ -10,7 +11,6 @@ import dateutil.parser
|
|||
from datetime import datetime
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
|
||||
import argparse
|
||||
|
@ -67,7 +67,6 @@ def run(args):
|
|||
logging.basicConfig(
|
||||
filename=filename, level=logging.INFO,
|
||||
format=settings.LOG_FORMAT)
|
||||
print("Logging into: " + filename)
|
||||
|
||||
if not args.pickle:
|
||||
dagbag = DagBag(args.subdir)
|
||||
|
@ -85,17 +84,19 @@ def run(args):
|
|||
# TODO: add run_local and fire it with the right executor from run
|
||||
ti = TaskInstance(task, args.execution_date)
|
||||
|
||||
# This is enough to fail the task instance
|
||||
def signal_handler(signum, frame):
|
||||
logging.error("SIGINT (ctrl-c) received".format(args.task_id))
|
||||
ti.error(args.execution_date)
|
||||
sys.exit()
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
|
||||
ti.run(
|
||||
if args.local:
|
||||
print("Logging into: " + filename)
|
||||
run_job = jobs.LocalTaskJob(
|
||||
task_instance=ti,
|
||||
mark_success=args.mark_success,
|
||||
force=args.force,
|
||||
ignore_dependencies=args.ignore_dependencies)
|
||||
run_job.run()
|
||||
else:
|
||||
executor = DEFAULT_EXECUTOR
|
||||
executor.start()
|
||||
executor.queue_command(ti.key, ti.command())
|
||||
executor.end()
|
||||
|
||||
|
||||
def list_dags(args):
|
||||
|
@ -166,7 +167,6 @@ def webserver(args):
|
|||
host=args.hostname, port=args.port))
|
||||
from tornado.httpserver import HTTPServer
|
||||
from tornado.ioloop import IOLoop
|
||||
from tornado import log
|
||||
from tornado.wsgi import WSGIContainer
|
||||
|
||||
# simple multi-process server
|
||||
|
@ -349,6 +349,10 @@ if __name__ == '__main__':
|
|||
"-f", "--force",
|
||||
help="Force a run regardless or previous success",
|
||||
action="store_true")
|
||||
parser_run.add_argument(
|
||||
"-l", "--local",
|
||||
help="Runs the task locally, don't use the executor",
|
||||
action="store_true")
|
||||
parser_run.add_argument(
|
||||
"-i", "--ignore_dependencies",
|
||||
help="Ignore upstream and depends_on_past dependencies",
|
||||
|
|
|
@ -23,7 +23,8 @@ class SequentialExecutor(BaseExecutor):
|
|||
def heartbeat(self):
|
||||
for key, command in self.commands_to_run:
|
||||
try:
|
||||
sp = subprocess.Popen(command, shell=True).wait()
|
||||
sp = subprocess.Popen(command, shell=True)
|
||||
sp.wait()
|
||||
except Exception as e:
|
||||
self.change_state(key, State.FAILED)
|
||||
raise e
|
||||
|
@ -31,4 +32,4 @@ class SequentialExecutor(BaseExecutor):
|
|||
self.commands_to_run = []
|
||||
|
||||
def end(self):
|
||||
pass
|
||||
self.heartbeat()
|
||||
|
|
|
@ -3,10 +3,11 @@ import getpass
|
|||
import logging
|
||||
import signal
|
||||
import sys
|
||||
import threading
|
||||
from time import sleep
|
||||
|
||||
from sqlalchemy import (
|
||||
Column, Integer, String, DateTime, ForeignKey)
|
||||
Column, Integer, String, DateTime)
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm.session import make_transient
|
||||
|
||||
|
@ -57,7 +58,7 @@ class BaseJob(Base):
|
|||
self.executor = executor
|
||||
self.executor_class = executor.__class__.__name__
|
||||
self.start_date = datetime.now()
|
||||
self.latest_heartbeat = None
|
||||
self.latest_heartbeat = datetime.now()
|
||||
self.heartrate = heartrate
|
||||
self.unixname = getpass.getuser()
|
||||
super(BaseJob, self).__init__(*args, **kwargs)
|
||||
|
@ -68,6 +69,25 @@ class BaseJob(Base):
|
|||
(conf.getint('misc', 'JOB_HEARTBEAT_SEC') * 2.1)
|
||||
)
|
||||
|
||||
def kill(self):
|
||||
session = settings.Session()
|
||||
job = session.query(BaseJob).filter(BaseJob.id==self.id).first()
|
||||
job.state = State.FAILED
|
||||
job.end_date = datetime.now()
|
||||
try:
|
||||
self.on_kill()
|
||||
except:
|
||||
logging.error('on_kill() method failed')
|
||||
session.merge(job)
|
||||
session.commit()
|
||||
session.close()
|
||||
|
||||
def on_kill(self):
|
||||
'''
|
||||
Will be called when an external kill command is received from the db
|
||||
'''
|
||||
pass
|
||||
|
||||
def heartbeat(self):
|
||||
'''
|
||||
Heartbeats update the job's entry in the the database with a timestamp
|
||||
|
@ -87,14 +107,12 @@ class BaseJob(Base):
|
|||
heart rate. If you go over 60 seconds before calling it, it won't
|
||||
sleep at all.
|
||||
'''
|
||||
failed = False
|
||||
session = settings.Session()
|
||||
job = session.query(BaseJob).filter(BaseJob.id==self.id).first()
|
||||
|
||||
if job.state == State.SHUTDOWN:
|
||||
job.state = State.FAILED
|
||||
job.end_date = datetime.now()
|
||||
failed = True
|
||||
self.kill()
|
||||
raise Exception("Task shut down externally")
|
||||
|
||||
if job.latest_heartbeat:
|
||||
sleep_for = self.heartrate - (
|
||||
|
@ -107,9 +125,7 @@ class BaseJob(Base):
|
|||
session.merge(job)
|
||||
session.commit()
|
||||
session.close()
|
||||
|
||||
if failed:
|
||||
raise Exception("Task shut down externally")
|
||||
logging.info('[heart] Boom.')
|
||||
|
||||
def run(self):
|
||||
# Adding an entry in the DB
|
||||
|
@ -122,6 +138,13 @@ class BaseJob(Base):
|
|||
self.id = id_
|
||||
|
||||
# Run!
|
||||
# This is enough to fail the task instance
|
||||
def signal_handler(signum, frame):
|
||||
logging.error("SIGINT (ctrl-c) received")
|
||||
self.kill()
|
||||
sys.exit()
|
||||
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
self._execute()
|
||||
|
||||
# Marking the success in the DB
|
||||
|
@ -338,5 +361,43 @@ class BackfillJob(BaseJob):
|
|||
succeeded.append(key)
|
||||
del tasks_to_run[key]
|
||||
executor.end()
|
||||
logging.info("Run summary:")
|
||||
session.close()
|
||||
|
||||
|
||||
class LocalTaskJob(BaseJob):
|
||||
|
||||
__mapper_args__ = {
|
||||
'polymorphic_identity': 'LocalTaskJob'
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
task_instance,
|
||||
ignore_dependencies=False,
|
||||
force=False,
|
||||
mark_success=False,
|
||||
*args, **kwargs):
|
||||
self.task_instance = task_instance
|
||||
self.ignore_dependencies = ignore_dependencies
|
||||
self.force = force
|
||||
self.mark_success = mark_success
|
||||
super(LocalTaskJob, self).__init__(*args, **kwargs)
|
||||
|
||||
def _execute(self):
|
||||
|
||||
thr = threading.Thread(
|
||||
target=self.task_instance.run,
|
||||
kwargs={
|
||||
'ignore_dependencies': self.ignore_dependencies,
|
||||
'force': self.force,
|
||||
'mark_success': self.mark_success,
|
||||
})
|
||||
self.thr = thr
|
||||
|
||||
thr.start()
|
||||
while thr.is_alive():
|
||||
self.heartbeat()
|
||||
|
||||
def on_kill(self):
|
||||
self.task_instance.error(self.task_instance.execution_date)
|
||||
self.thr.kill()
|
||||
|
|
|
@ -256,18 +256,19 @@ class TaskInstance(Base):
|
|||
mark_success=False,
|
||||
ignore_dependencies=False,
|
||||
force=False,
|
||||
local=True,
|
||||
pickle_id=None):
|
||||
"""
|
||||
Returns a command that can be executed anywhere where airflow is
|
||||
installed. This command is part of the message sent to executors by
|
||||
the orchestrator.
|
||||
"""
|
||||
|
||||
iso = self.execution_date.isoformat()
|
||||
mark_success = "--mark_success" if mark_success else ""
|
||||
pickle = "--pickle {0}".format(pickle_id) if pickle_id else ""
|
||||
ignore_dependencies = "-i" if ignore_dependencies else ""
|
||||
force = "--force" if force else ""
|
||||
local = "--local" if local else ""
|
||||
subdir = ""
|
||||
if not pickle and self.task.dag and self.task.dag.full_filepath:
|
||||
subdir = "-sd {0}".format(self.task.dag.full_filepath)
|
||||
|
@ -276,6 +277,7 @@ class TaskInstance(Base):
|
|||
"{self.dag_id} {self.task_id} {iso} "
|
||||
"{mark_success} "
|
||||
"{pickle} "
|
||||
"{local} "
|
||||
"{ignore_dependencies} "
|
||||
"{force} "
|
||||
"{subdir} "
|
||||
|
|
1
setup.py
1
setup.py
|
@ -27,6 +27,7 @@ setup(
|
|||
'mysql-python>=1.2.5',
|
||||
'pandas>=0.15.2',
|
||||
'pygments>=2.0.1',
|
||||
'pysmbclient>=0.1.3'
|
||||
'pyhive>=0.1.3',
|
||||
'python-dateutil>=2.3',
|
||||
'requests>=2.5.1',
|
||||
|
|
Загрузка…
Ссылка в новой задаче