Merge pull request #96 from mistercrunch/async_ti_job

New RunTaskJob that runs async, implements kill signal
This commit is contained in:
Maxime Beauchemin 2015-01-21 22:13:01 -08:00
Родитель c67bb2245e 4a12bd5838
Коммит 689853a797
5 изменённых файлов: 96 добавлений и 27 удалений

Просмотреть файл

@ -1,6 +1,7 @@
#!/usr/bin/env python #!/usr/bin/env python
from airflow.configuration import conf from airflow.configuration import conf
from airflow.executors import DEFAULT_EXECUTOR
from airflow import settings from airflow import settings
from airflow import utils from airflow import utils
from airflow import jobs from airflow import jobs
@ -10,7 +11,6 @@ import dateutil.parser
from datetime import datetime from datetime import datetime
import logging import logging
import os import os
import signal
import sys import sys
import argparse import argparse
@ -67,7 +67,6 @@ def run(args):
logging.basicConfig( logging.basicConfig(
filename=filename, level=logging.INFO, filename=filename, level=logging.INFO,
format=settings.LOG_FORMAT) format=settings.LOG_FORMAT)
print("Logging into: " + filename)
if not args.pickle: if not args.pickle:
dagbag = DagBag(args.subdir) dagbag = DagBag(args.subdir)
@ -85,17 +84,19 @@ def run(args):
# TODO: add run_local and fire it with the right executor from run # TODO: add run_local and fire it with the right executor from run
ti = TaskInstance(task, args.execution_date) ti = TaskInstance(task, args.execution_date)
# This is enough to fail the task instance if args.local:
def signal_handler(signum, frame): print("Logging into: " + filename)
logging.error("SIGINT (ctrl-c) received".format(args.task_id)) run_job = jobs.LocalTaskJob(
ti.error(args.execution_date) task_instance=ti,
sys.exit() mark_success=args.mark_success,
signal.signal(signal.SIGINT, signal_handler) force=args.force,
ignore_dependencies=args.ignore_dependencies)
ti.run( run_job.run()
mark_success=args.mark_success, else:
force=args.force, executor = DEFAULT_EXECUTOR
ignore_dependencies=args.ignore_dependencies) executor.start()
executor.queue_command(ti.key, ti.command())
executor.end()
def list_dags(args): def list_dags(args):
@ -166,7 +167,6 @@ def webserver(args):
host=args.hostname, port=args.port)) host=args.hostname, port=args.port))
from tornado.httpserver import HTTPServer from tornado.httpserver import HTTPServer
from tornado.ioloop import IOLoop from tornado.ioloop import IOLoop
from tornado import log
from tornado.wsgi import WSGIContainer from tornado.wsgi import WSGIContainer
# simple multi-process server # simple multi-process server
@ -349,6 +349,10 @@ if __name__ == '__main__':
"-f", "--force", "-f", "--force",
help="Force a run regardless or previous success", help="Force a run regardless or previous success",
action="store_true") action="store_true")
parser_run.add_argument(
"-l", "--local",
help="Runs the task locally, don't use the executor",
action="store_true")
parser_run.add_argument( parser_run.add_argument(
"-i", "--ignore_dependencies", "-i", "--ignore_dependencies",
help="Ignore upstream and depends_on_past dependencies", help="Ignore upstream and depends_on_past dependencies",

Просмотреть файл

@ -23,7 +23,8 @@ class SequentialExecutor(BaseExecutor):
def heartbeat(self): def heartbeat(self):
for key, command in self.commands_to_run: for key, command in self.commands_to_run:
try: try:
sp = subprocess.Popen(command, shell=True).wait() sp = subprocess.Popen(command, shell=True)
sp.wait()
except Exception as e: except Exception as e:
self.change_state(key, State.FAILED) self.change_state(key, State.FAILED)
raise e raise e
@ -31,4 +32,4 @@ class SequentialExecutor(BaseExecutor):
self.commands_to_run = [] self.commands_to_run = []
def end(self): def end(self):
pass self.heartbeat()

Просмотреть файл

@ -3,10 +3,11 @@ import getpass
import logging import logging
import signal import signal
import sys import sys
import threading
from time import sleep from time import sleep
from sqlalchemy import ( from sqlalchemy import (
Column, Integer, String, DateTime, ForeignKey) Column, Integer, String, DateTime)
from sqlalchemy import func from sqlalchemy import func
from sqlalchemy.orm.session import make_transient from sqlalchemy.orm.session import make_transient
@ -57,7 +58,7 @@ class BaseJob(Base):
self.executor = executor self.executor = executor
self.executor_class = executor.__class__.__name__ self.executor_class = executor.__class__.__name__
self.start_date = datetime.now() self.start_date = datetime.now()
self.latest_heartbeat = None self.latest_heartbeat = datetime.now()
self.heartrate = heartrate self.heartrate = heartrate
self.unixname = getpass.getuser() self.unixname = getpass.getuser()
super(BaseJob, self).__init__(*args, **kwargs) super(BaseJob, self).__init__(*args, **kwargs)
@ -68,6 +69,25 @@ class BaseJob(Base):
(conf.getint('misc', 'JOB_HEARTBEAT_SEC') * 2.1) (conf.getint('misc', 'JOB_HEARTBEAT_SEC') * 2.1)
) )
def kill(self):
session = settings.Session()
job = session.query(BaseJob).filter(BaseJob.id==self.id).first()
job.state = State.FAILED
job.end_date = datetime.now()
try:
self.on_kill()
except:
logging.error('on_kill() method failed')
session.merge(job)
session.commit()
session.close()
def on_kill(self):
'''
Will be called when an external kill command is received from the db
'''
pass
def heartbeat(self): def heartbeat(self):
''' '''
Heartbeats update the job's entry in the the database with a timestamp Heartbeats update the job's entry in the the database with a timestamp
@ -87,14 +107,12 @@ class BaseJob(Base):
heart rate. If you go over 60 seconds before calling it, it won't heart rate. If you go over 60 seconds before calling it, it won't
sleep at all. sleep at all.
''' '''
failed = False
session = settings.Session() session = settings.Session()
job = session.query(BaseJob).filter(BaseJob.id==self.id).first() job = session.query(BaseJob).filter(BaseJob.id==self.id).first()
if job.state == State.SHUTDOWN: if job.state == State.SHUTDOWN:
job.state = State.FAILED self.kill()
job.end_date = datetime.now() raise Exception("Task shut down externally")
failed = True
if job.latest_heartbeat: if job.latest_heartbeat:
sleep_for = self.heartrate - ( sleep_for = self.heartrate - (
@ -107,9 +125,7 @@ class BaseJob(Base):
session.merge(job) session.merge(job)
session.commit() session.commit()
session.close() session.close()
logging.info('[heart] Boom.')
if failed:
raise Exception("Task shut down externally")
def run(self): def run(self):
# Adding an entry in the DB # Adding an entry in the DB
@ -122,6 +138,13 @@ class BaseJob(Base):
self.id = id_ self.id = id_
# Run! # Run!
# This is enough to fail the task instance
def signal_handler(signum, frame):
logging.error("SIGINT (ctrl-c) received")
self.kill()
sys.exit()
signal.signal(signal.SIGINT, signal_handler)
self._execute() self._execute()
# Marking the success in the DB # Marking the success in the DB
@ -338,5 +361,43 @@ class BackfillJob(BaseJob):
succeeded.append(key) succeeded.append(key)
del tasks_to_run[key] del tasks_to_run[key]
executor.end() executor.end()
logging.info("Run summary:")
session.close() session.close()
class LocalTaskJob(BaseJob):
__mapper_args__ = {
'polymorphic_identity': 'LocalTaskJob'
}
def __init__(
self,
task_instance,
ignore_dependencies=False,
force=False,
mark_success=False,
*args, **kwargs):
self.task_instance = task_instance
self.ignore_dependencies = ignore_dependencies
self.force = force
self.mark_success = mark_success
super(LocalTaskJob, self).__init__(*args, **kwargs)
def _execute(self):
thr = threading.Thread(
target=self.task_instance.run,
kwargs={
'ignore_dependencies': self.ignore_dependencies,
'force': self.force,
'mark_success': self.mark_success,
})
self.thr = thr
thr.start()
while thr.is_alive():
self.heartbeat()
def on_kill(self):
self.task_instance.error(self.task_instance.execution_date)
self.thr.kill()

Просмотреть файл

@ -256,18 +256,19 @@ class TaskInstance(Base):
mark_success=False, mark_success=False,
ignore_dependencies=False, ignore_dependencies=False,
force=False, force=False,
local=True,
pickle_id=None): pickle_id=None):
""" """
Returns a command that can be executed anywhere where airflow is Returns a command that can be executed anywhere where airflow is
installed. This command is part of the message sent to executors by installed. This command is part of the message sent to executors by
the orchestrator. the orchestrator.
""" """
iso = self.execution_date.isoformat() iso = self.execution_date.isoformat()
mark_success = "--mark_success" if mark_success else "" mark_success = "--mark_success" if mark_success else ""
pickle = "--pickle {0}".format(pickle_id) if pickle_id else "" pickle = "--pickle {0}".format(pickle_id) if pickle_id else ""
ignore_dependencies = "-i" if ignore_dependencies else "" ignore_dependencies = "-i" if ignore_dependencies else ""
force = "--force" if force else "" force = "--force" if force else ""
local = "--local" if local else ""
subdir = "" subdir = ""
if not pickle and self.task.dag and self.task.dag.full_filepath: if not pickle and self.task.dag and self.task.dag.full_filepath:
subdir = "-sd {0}".format(self.task.dag.full_filepath) subdir = "-sd {0}".format(self.task.dag.full_filepath)
@ -276,6 +277,7 @@ class TaskInstance(Base):
"{self.dag_id} {self.task_id} {iso} " "{self.dag_id} {self.task_id} {iso} "
"{mark_success} " "{mark_success} "
"{pickle} " "{pickle} "
"{local} "
"{ignore_dependencies} " "{ignore_dependencies} "
"{force} " "{force} "
"{subdir} " "{subdir} "

Просмотреть файл

@ -27,6 +27,7 @@ setup(
'mysql-python>=1.2.5', 'mysql-python>=1.2.5',
'pandas>=0.15.2', 'pandas>=0.15.2',
'pygments>=2.0.1', 'pygments>=2.0.1',
'pysmbclient>=0.1.3',
'pyhive>=0.1.3', 'pyhive>=0.1.3',
'python-dateutil>=2.3', 'python-dateutil>=2.3',
'requests>=2.5.1', 'requests>=2.5.1',