Merge pull request #96 from mistercrunch/async_ti_job
New RunTaskJob that runs async, implements kill signal
This commit is contained in:
Коммит
689853a797
|
@ -1,6 +1,7 @@
|
||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
|
|
||||||
from airflow.configuration import conf
|
from airflow.configuration import conf
|
||||||
|
from airflow.executors import DEFAULT_EXECUTOR
|
||||||
from airflow import settings
|
from airflow import settings
|
||||||
from airflow import utils
|
from airflow import utils
|
||||||
from airflow import jobs
|
from airflow import jobs
|
||||||
|
@ -10,7 +11,6 @@ import dateutil.parser
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import signal
|
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
@ -67,7 +67,6 @@ def run(args):
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
filename=filename, level=logging.INFO,
|
filename=filename, level=logging.INFO,
|
||||||
format=settings.LOG_FORMAT)
|
format=settings.LOG_FORMAT)
|
||||||
print("Logging into: " + filename)
|
|
||||||
|
|
||||||
if not args.pickle:
|
if not args.pickle:
|
||||||
dagbag = DagBag(args.subdir)
|
dagbag = DagBag(args.subdir)
|
||||||
|
@ -85,17 +84,19 @@ def run(args):
|
||||||
# TODO: add run_local and fire it with the right executor from run
|
# TODO: add run_local and fire it with the right executor from run
|
||||||
ti = TaskInstance(task, args.execution_date)
|
ti = TaskInstance(task, args.execution_date)
|
||||||
|
|
||||||
# This is enough to fail the task instance
|
if args.local:
|
||||||
def signal_handler(signum, frame):
|
print("Logging into: " + filename)
|
||||||
logging.error("SIGINT (ctrl-c) received".format(args.task_id))
|
run_job = jobs.LocalTaskJob(
|
||||||
ti.error(args.execution_date)
|
task_instance=ti,
|
||||||
sys.exit()
|
mark_success=args.mark_success,
|
||||||
signal.signal(signal.SIGINT, signal_handler)
|
force=args.force,
|
||||||
|
ignore_dependencies=args.ignore_dependencies)
|
||||||
ti.run(
|
run_job.run()
|
||||||
mark_success=args.mark_success,
|
else:
|
||||||
force=args.force,
|
executor = DEFAULT_EXECUTOR
|
||||||
ignore_dependencies=args.ignore_dependencies)
|
executor.start()
|
||||||
|
executor.queue_command(ti.key, ti.command())
|
||||||
|
executor.end()
|
||||||
|
|
||||||
|
|
||||||
def list_dags(args):
|
def list_dags(args):
|
||||||
|
@ -166,7 +167,6 @@ def webserver(args):
|
||||||
host=args.hostname, port=args.port))
|
host=args.hostname, port=args.port))
|
||||||
from tornado.httpserver import HTTPServer
|
from tornado.httpserver import HTTPServer
|
||||||
from tornado.ioloop import IOLoop
|
from tornado.ioloop import IOLoop
|
||||||
from tornado import log
|
|
||||||
from tornado.wsgi import WSGIContainer
|
from tornado.wsgi import WSGIContainer
|
||||||
|
|
||||||
# simple multi-process server
|
# simple multi-process server
|
||||||
|
@ -349,6 +349,10 @@ if __name__ == '__main__':
|
||||||
"-f", "--force",
|
"-f", "--force",
|
||||||
help="Force a run regardless or previous success",
|
help="Force a run regardless or previous success",
|
||||||
action="store_true")
|
action="store_true")
|
||||||
|
parser_run.add_argument(
|
||||||
|
"-l", "--local",
|
||||||
|
help="Runs the task locally, don't use the executor",
|
||||||
|
action="store_true")
|
||||||
parser_run.add_argument(
|
parser_run.add_argument(
|
||||||
"-i", "--ignore_dependencies",
|
"-i", "--ignore_dependencies",
|
||||||
help="Ignore upstream and depends_on_past dependencies",
|
help="Ignore upstream and depends_on_past dependencies",
|
||||||
|
|
|
@ -23,7 +23,8 @@ class SequentialExecutor(BaseExecutor):
|
||||||
def heartbeat(self):
|
def heartbeat(self):
|
||||||
for key, command in self.commands_to_run:
|
for key, command in self.commands_to_run:
|
||||||
try:
|
try:
|
||||||
sp = subprocess.Popen(command, shell=True).wait()
|
sp = subprocess.Popen(command, shell=True)
|
||||||
|
sp.wait()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.change_state(key, State.FAILED)
|
self.change_state(key, State.FAILED)
|
||||||
raise e
|
raise e
|
||||||
|
@ -31,4 +32,4 @@ class SequentialExecutor(BaseExecutor):
|
||||||
self.commands_to_run = []
|
self.commands_to_run = []
|
||||||
|
|
||||||
def end(self):
|
def end(self):
|
||||||
pass
|
self.heartbeat()
|
||||||
|
|
|
@ -3,10 +3,11 @@ import getpass
|
||||||
import logging
|
import logging
|
||||||
import signal
|
import signal
|
||||||
import sys
|
import sys
|
||||||
|
import threading
|
||||||
from time import sleep
|
from time import sleep
|
||||||
|
|
||||||
from sqlalchemy import (
|
from sqlalchemy import (
|
||||||
Column, Integer, String, DateTime, ForeignKey)
|
Column, Integer, String, DateTime)
|
||||||
from sqlalchemy import func
|
from sqlalchemy import func
|
||||||
from sqlalchemy.orm.session import make_transient
|
from sqlalchemy.orm.session import make_transient
|
||||||
|
|
||||||
|
@ -57,7 +58,7 @@ class BaseJob(Base):
|
||||||
self.executor = executor
|
self.executor = executor
|
||||||
self.executor_class = executor.__class__.__name__
|
self.executor_class = executor.__class__.__name__
|
||||||
self.start_date = datetime.now()
|
self.start_date = datetime.now()
|
||||||
self.latest_heartbeat = None
|
self.latest_heartbeat = datetime.now()
|
||||||
self.heartrate = heartrate
|
self.heartrate = heartrate
|
||||||
self.unixname = getpass.getuser()
|
self.unixname = getpass.getuser()
|
||||||
super(BaseJob, self).__init__(*args, **kwargs)
|
super(BaseJob, self).__init__(*args, **kwargs)
|
||||||
|
@ -68,6 +69,25 @@ class BaseJob(Base):
|
||||||
(conf.getint('misc', 'JOB_HEARTBEAT_SEC') * 2.1)
|
(conf.getint('misc', 'JOB_HEARTBEAT_SEC') * 2.1)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def kill(self):
|
||||||
|
session = settings.Session()
|
||||||
|
job = session.query(BaseJob).filter(BaseJob.id==self.id).first()
|
||||||
|
job.state = State.FAILED
|
||||||
|
job.end_date = datetime.now()
|
||||||
|
try:
|
||||||
|
self.on_kill()
|
||||||
|
except:
|
||||||
|
logging.error('on_kill() method failed')
|
||||||
|
session.merge(job)
|
||||||
|
session.commit()
|
||||||
|
session.close()
|
||||||
|
|
||||||
|
def on_kill(self):
|
||||||
|
'''
|
||||||
|
Will be called when an external kill command is received from the db
|
||||||
|
'''
|
||||||
|
pass
|
||||||
|
|
||||||
def heartbeat(self):
|
def heartbeat(self):
|
||||||
'''
|
'''
|
||||||
Heartbeats update the job's entry in the database with a timestamp
|
Heartbeats update the job's entry in the database with a timestamp
|
||||||
|
@ -87,14 +107,12 @@ class BaseJob(Base):
|
||||||
heart rate. If you go over 60 seconds before calling it, it won't
|
heart rate. If you go over 60 seconds before calling it, it won't
|
||||||
sleep at all.
|
sleep at all.
|
||||||
'''
|
'''
|
||||||
failed = False
|
|
||||||
session = settings.Session()
|
session = settings.Session()
|
||||||
job = session.query(BaseJob).filter(BaseJob.id==self.id).first()
|
job = session.query(BaseJob).filter(BaseJob.id==self.id).first()
|
||||||
|
|
||||||
if job.state == State.SHUTDOWN:
|
if job.state == State.SHUTDOWN:
|
||||||
job.state = State.FAILED
|
self.kill()
|
||||||
job.end_date = datetime.now()
|
raise Exception("Task shut down externally")
|
||||||
failed = True
|
|
||||||
|
|
||||||
if job.latest_heartbeat:
|
if job.latest_heartbeat:
|
||||||
sleep_for = self.heartrate - (
|
sleep_for = self.heartrate - (
|
||||||
|
@ -107,9 +125,7 @@ class BaseJob(Base):
|
||||||
session.merge(job)
|
session.merge(job)
|
||||||
session.commit()
|
session.commit()
|
||||||
session.close()
|
session.close()
|
||||||
|
logging.info('[heart] Boom.')
|
||||||
if failed:
|
|
||||||
raise Exception("Task shut down externally")
|
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
# Adding an entry in the DB
|
# Adding an entry in the DB
|
||||||
|
@ -122,6 +138,13 @@ class BaseJob(Base):
|
||||||
self.id = id_
|
self.id = id_
|
||||||
|
|
||||||
# Run!
|
# Run!
|
||||||
|
# This is enough to fail the task instance
|
||||||
|
def signal_handler(signum, frame):
|
||||||
|
logging.error("SIGINT (ctrl-c) received")
|
||||||
|
self.kill()
|
||||||
|
sys.exit()
|
||||||
|
|
||||||
|
signal.signal(signal.SIGINT, signal_handler)
|
||||||
self._execute()
|
self._execute()
|
||||||
|
|
||||||
# Marking the success in the DB
|
# Marking the success in the DB
|
||||||
|
@ -338,5 +361,43 @@ class BackfillJob(BaseJob):
|
||||||
succeeded.append(key)
|
succeeded.append(key)
|
||||||
del tasks_to_run[key]
|
del tasks_to_run[key]
|
||||||
executor.end()
|
executor.end()
|
||||||
logging.info("Run summary:")
|
|
||||||
session.close()
|
session.close()
|
||||||
|
|
||||||
|
|
||||||
|
class LocalTaskJob(BaseJob):
|
||||||
|
|
||||||
|
__mapper_args__ = {
|
||||||
|
'polymorphic_identity': 'LocalTaskJob'
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
task_instance,
|
||||||
|
ignore_dependencies=False,
|
||||||
|
force=False,
|
||||||
|
mark_success=False,
|
||||||
|
*args, **kwargs):
|
||||||
|
self.task_instance = task_instance
|
||||||
|
self.ignore_dependencies = ignore_dependencies
|
||||||
|
self.force = force
|
||||||
|
self.mark_success = mark_success
|
||||||
|
super(LocalTaskJob, self).__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
def _execute(self):
|
||||||
|
|
||||||
|
thr = threading.Thread(
|
||||||
|
target=self.task_instance.run,
|
||||||
|
kwargs={
|
||||||
|
'ignore_dependencies': self.ignore_dependencies,
|
||||||
|
'force': self.force,
|
||||||
|
'mark_success': self.mark_success,
|
||||||
|
})
|
||||||
|
self.thr = thr
|
||||||
|
|
||||||
|
thr.start()
|
||||||
|
while thr.is_alive():
|
||||||
|
self.heartbeat()
|
||||||
|
|
||||||
|
def on_kill(self):
|
||||||
|
self.task_instance.error(self.task_instance.execution_date)
|
||||||
|
self.thr.kill()
|
||||||
|
|
|
@ -256,18 +256,19 @@ class TaskInstance(Base):
|
||||||
mark_success=False,
|
mark_success=False,
|
||||||
ignore_dependencies=False,
|
ignore_dependencies=False,
|
||||||
force=False,
|
force=False,
|
||||||
|
local=True,
|
||||||
pickle_id=None):
|
pickle_id=None):
|
||||||
"""
|
"""
|
||||||
Returns a command that can be executed anywhere where airflow is
|
Returns a command that can be executed anywhere where airflow is
|
||||||
installed. This command is part of the message sent to executors by
|
installed. This command is part of the message sent to executors by
|
||||||
the orchestrator.
|
the orchestrator.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
iso = self.execution_date.isoformat()
|
iso = self.execution_date.isoformat()
|
||||||
mark_success = "--mark_success" if mark_success else ""
|
mark_success = "--mark_success" if mark_success else ""
|
||||||
pickle = "--pickle {0}".format(pickle_id) if pickle_id else ""
|
pickle = "--pickle {0}".format(pickle_id) if pickle_id else ""
|
||||||
ignore_dependencies = "-i" if ignore_dependencies else ""
|
ignore_dependencies = "-i" if ignore_dependencies else ""
|
||||||
force = "--force" if force else ""
|
force = "--force" if force else ""
|
||||||
|
local = "--local" if local else ""
|
||||||
subdir = ""
|
subdir = ""
|
||||||
if not pickle and self.task.dag and self.task.dag.full_filepath:
|
if not pickle and self.task.dag and self.task.dag.full_filepath:
|
||||||
subdir = "-sd {0}".format(self.task.dag.full_filepath)
|
subdir = "-sd {0}".format(self.task.dag.full_filepath)
|
||||||
|
@ -276,6 +277,7 @@ class TaskInstance(Base):
|
||||||
"{self.dag_id} {self.task_id} {iso} "
|
"{self.dag_id} {self.task_id} {iso} "
|
||||||
"{mark_success} "
|
"{mark_success} "
|
||||||
"{pickle} "
|
"{pickle} "
|
||||||
|
"{local} "
|
||||||
"{ignore_dependencies} "
|
"{ignore_dependencies} "
|
||||||
"{force} "
|
"{force} "
|
||||||
"{subdir} "
|
"{subdir} "
|
||||||
|
|
1
setup.py
1
setup.py
|
@ -27,6 +27,7 @@ setup(
|
||||||
'mysql-python>=1.2.5',
|
'mysql-python>=1.2.5',
|
||||||
'pandas>=0.15.2',
|
'pandas>=0.15.2',
|
||||||
'pygments>=2.0.1',
|
'pygments>=2.0.1',
|
||||||
|
'pysmbclient>=0.1.3'
|
||||||
'pyhive>=0.1.3',
|
'pyhive>=0.1.3',
|
||||||
'python-dateutil>=2.3',
|
'python-dateutil>=2.3',
|
||||||
'requests>=2.5.1',
|
'requests>=2.5.1',
|
||||||
|
|
Загрузка…
Ссылка в новой задаче